Source code for gpflux.runtime_checks

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" Runtime checks """
from typing import Optional, Tuple

from gpflow.inducing_variables import (
    FallbackSeparateIndependentInducingVariables,
    MultioutputInducingVariables,
)
from gpflow.kernels import MultioutputKernel
from gpflow.mean_functions import MeanFunction

from gpflux.exceptions import GPLayerIncompatibilityException


[docs]def verify_compatibility( kernel: MultioutputKernel, mean_function: MeanFunction, inducing_variable: MultioutputInducingVariables, ) -> Tuple[int, int]: """ Checks that the arguments are all compatible with each other for use in a `GPLayer`. :param kernel: The multioutput kernel for the layer. :param inducing_variable: The inducing features for the layer. :param mean_function: The mean function applied to the inputs. :raises GPLayerIncompatibilityException: If an incompatibility is detected. :returns: number of inducing variables and number of latent GPs """ # TODO: This function could be simplified by upstream enhancements to # GPflow: e.g. by adding an `output_dim` attribute to # MultioutputInducingVariable subclasses if not isinstance(inducing_variable, MultioutputInducingVariables): raise GPLayerIncompatibilityException( "`inducing_variable` must be a `gpflow.inducing_variables.MultioutputInducingVariables`" ) if not isinstance(kernel, MultioutputKernel): raise GPLayerIncompatibilityException( "`kernel` must be a `gpflow.kernels.MultioutputKernel`" ) if not isinstance(mean_function, MeanFunction): raise GPLayerIncompatibilityException( "`kernel` must be a `gpflow.mean_functions.MeanFunction`" ) latent_inducing_points: Optional[int] = None if isinstance(inducing_variable, FallbackSeparateIndependentInducingVariables): latent_inducing_points = len(inducing_variable.inducing_variable_list) num_latent_gps = kernel.num_latent_gps if latent_inducing_points is not None: if latent_inducing_points != num_latent_gps: raise GPLayerIncompatibilityException( f"The number of latent GPs ({num_latent_gps}) does not match " f"the number of separate independent inducing_variables ({latent_inducing_points})" ) num_inducing_points = inducing_variable.num_inducing # currently the same for each dim return num_inducing_points, num_latent_gps