Source code for trieste.models.gpflow.sampler

# Copyright 2020 The Trieste Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is the home of the sampling functionality required by Trieste's
GPflow wrappers.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Optional, Tuple, TypeVar, Union, cast

import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import check_shapes
from gpflow.kernels import Kernel, MultioutputKernel
from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine
from gpflux.math import compute_A_inv_b
from typing_extensions import Protocol, TypeGuard, runtime_checkable

from ...space import EncoderFunction
from ...types import TensorType
from ...utils import DEFAULTS, flatten_leading_dims
from ..interfaces import (
    ProbabilisticModel,
    ReparametrizationSampler,
    SupportsGetInducingVariables,
    SupportsGetInternalData,
    SupportsGetKernel,
    SupportsGetMeanFunction,
    SupportsGetObservationNoise,
    SupportsPredictJoint,
    TrajectoryFunction,
    TrajectoryFunctionClass,
    TrajectorySampler,
    get_encoder,
)

_IntTensorType = Union[tf.Tensor, int]


def qmc_normal_samples(
    num_samples: _IntTensorType,
    n_sample_dim: _IntTensorType,
    skip: _IntTensorType = 0,
    dtype: tf.DType = tf.float64,
) -> tf.Tensor:
    """
    Generates `num_samples` Sobol samples, skipping the first `skip`, where each sample has
    dimension `n_sample_dim`.
    """
    if num_samples == 0 or n_sample_dim == 0:
        return tf.zeros(shape=(num_samples, n_sample_dim), dtype=dtype)

    sobol_samples = tf.math.sobol_sample(
        dim=n_sample_dim,
        num_results=num_samples,
        dtype=dtype,
        skip=skip,
    )
    dist = tfp.distributions.Normal(
        loc=tf.constant(0.0, dtype=dtype),
        scale=tf.constant(1.0, dtype=dtype),
    )
    normal_samples = dist.quantile(sobol_samples)
    return normal_samples

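# A minimal usage sketch of `qmc_normal_samples` (illustrative only, not part of the module);
# shapes follow the docstring above and only the function defined here is assumed:
#
#     samples = qmc_normal_samples(num_samples=4, n_sample_dim=2)  # [4, 2] quasi-random normals
#     more = qmc_normal_samples(4, 2, skip=4)  # continue the Sobol sequence past the first 4 draws
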
class IndependentReparametrizationSampler(ReparametrizationSampler[ProbabilisticModel]):
    r"""
    This sampler employs the *reparameterization trick* to approximate samples from a
    :class:`ProbabilisticModel`\ 's predictive distribution as

    .. math:: x \mapsto \mu(x) + \epsilon \sigma(x)

    where :math:`\epsilon \sim \mathcal N (0, 1)` is constant for a given sampler, thus ensuring
    samples form a continuous curve.
    """

    skip: TensorType = tf.Variable(0, trainable=False)
"""Number of sobol sequence points to skip. This is incremented for each sampler.""" def __init__( self, sample_size: int, model: ProbabilisticModel, qmc: bool = False, qmc_skip: bool = True ): """ :param sample_size: The number of samples to take at each point. Must be positive. :param model: The model to sample from. :param qmc: Whether to use QMC sobol sampling instead of random normal sampling. QMC sampling more accurately approximates a normal distribution than truly random samples. :param qmc_skip: Whether to use the skip parameter to ensure the QMC sampler gives different samples whenever it is reset. This is not supported with XLA. :raise ValueError (or InvalidArgumentError): If ``sample_size`` is not positive. """ super().__init__(sample_size, model) self._eps: Optional[tf.Variable] = None self._qmc = qmc self._qmc_skip = qmc_skip @check_shapes( "at: [N..., 1, D] # IndependentReparametrizationSampler only supports batch sizes of one", "return: [N..., S, 1, L]", )
    def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorType:
        """
        Return approximate samples from the `model` specified at :meth:`__init__`. Multiple calls
        to :meth:`sample`, for any given :class:`IndependentReparametrizationSampler` and ``at``,
        will produce the exact same samples. Calls to :meth:`sample` on *different*
        :class:`IndependentReparametrizationSampler` instances will produce different samples.

        :param at: Where to sample the predictive distribution, with shape `[..., 1, D]`, for
            points of dimension `D`.
        :param jitter: The size of the jitter to use when stabilising the Cholesky decomposition
            of the covariance matrix.
        :return: The samples, of shape `[..., S, 1, L]`, where `S` is the `sample_size` and `L` is
            the number of latent model dimensions.
        :raise ValueError (or InvalidArgumentError): If ``at`` has an invalid shape or ``jitter``
            is negative.
        """
        tf.debugging.assert_greater_equal(jitter, 0.0)

        mean, var = self._model.predict(at[..., None, :, :])  # [..., 1, 1, L], [..., 1, 1, L]
        var = var + jitter

        def sample_eps() -> tf.Tensor:
            self._initialized.assign(True)
            if self._qmc:
                if self._qmc_skip:
                    skip = IndependentReparametrizationSampler.skip
                    IndependentReparametrizationSampler.skip.assign(skip + self._sample_size)
                else:
                    skip = tf.constant(0)
                normal_samples = qmc_normal_samples(
                    self._sample_size, mean.shape[-1], skip, dtype=var.dtype
                )
            else:
                normal_samples = tf.random.normal(
                    [self._sample_size, tf.shape(mean)[-1]], dtype=var.dtype
                )
            return normal_samples  # [S, L]

        if self._eps is None:
            self._eps = tf.Variable(sample_eps())
        tf.cond(
            self._initialized,
            lambda: self._eps,
            lambda: self._eps.assign(sample_eps()),
        )

        return mean + tf.sqrt(var) * self._eps[:, None, :]  # [..., S, 1, L]

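# A minimal usage sketch (illustrative only, not part of the module): `model` is a hypothetical
# trained model implementing `ProbabilisticModel`, and `xs` a hypothetical [N, D] tensor of query
# points; shapes follow the docstring above.
#
#     sampler = IndependentReparametrizationSampler(sample_size=10, model=model, qmc=True)
#     samples = sampler.sample(xs[:, None, :])  # [N, 10, 1, L]; repeated calls reuse the same eps
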
class BatchReparametrizationSampler(ReparametrizationSampler[SupportsPredictJoint]):
    r"""
    This sampler employs the *reparameterization trick* to approximate batches of samples from a
    :class:`ProbabilisticModel`\ 's predictive joint distribution as

    .. math:: x \mapsto \mu(x) + \epsilon L(x)

    where :math:`L` is the Cholesky factor s.t. :math:`LL^T` is the covariance, and
    :math:`\epsilon \sim \mathcal N (0, 1)` is constant for a given sampler, thus ensuring samples
    form a continuous curve.
    """

    skip: TensorType = tf.Variable(0, trainable=False)
"""Number of sobol sequence points to skip. This is incremented for each sampler.""" def __init__( self, sample_size: int, model: SupportsPredictJoint, qmc: bool = False, qmc_skip: bool = True, ): """ :param sample_size: The number of samples for each batch of points. Must be positive. :param model: The model to sample from. :param qmc: Whether to use QMC sobol sampling instead of random normal sampling. QMC sampling more accurately approximates a normal distribution than truly random samples. :param qmc_skip: Whether to use the skip parameter to ensure the QMC sampler gives different samples whenever it is reset. This is not supported with XLA. :raise ValueError (or InvalidArgumentError): If ``sample_size`` is not positive. """ super().__init__(sample_size, model) if not isinstance(model, SupportsPredictJoint): raise NotImplementedError( f"BatchReparametrizationSampler only works with models that support " f"predict_joint; received {model!r}" ) self._eps: Optional[tf.Variable] = None self._qmc = qmc self._qmc_skip = qmc_skip
    def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorType:
        """
        Return approximate samples from the `model` specified at :meth:`__init__`. Multiple calls
        to :meth:`sample`, for any given :class:`BatchReparametrizationSampler` and ``at``, will
        produce the exact same samples. Calls to :meth:`sample` on *different*
        :class:`BatchReparametrizationSampler` instances will produce different samples.

        :param at: Batches of query points at which to sample the predictive distribution, with
            shape `[..., B, D]`, for batches of size `B` of points of dimension `D`. Must have a
            consistent batch size across all calls to :meth:`sample` for any given
            :class:`BatchReparametrizationSampler`.
        :param jitter: The size of the jitter to use when stabilising the Cholesky decomposition
            of the covariance matrix.
        :return: The samples, of shape `[..., S, B, L]`, where `S` is the `sample_size`, `B` the
            number of points per batch, and `L` the dimension of the model's predictive
            distribution.
        :raise ValueError (or InvalidArgumentError): If any of the following are true:

            - ``at`` is a scalar.
            - The batch size `B` of ``at`` is not positive.
            - The batch size `B` of ``at`` differs from that of previous calls.
            - ``jitter`` is negative.
        """
        tf.debugging.assert_rank_at_least(at, 2)
        tf.debugging.assert_greater_equal(jitter, 0.0)

        batch_size = at.shape[-2]

        tf.debugging.assert_positive(batch_size)

        mean, cov = self._model.predict_joint(at)  # [..., B, L], [..., L, B, B]

        def sample_eps() -> tf.Tensor:
            self._initialized.assign(True)
            if self._qmc:
                if self._qmc_skip:
                    skip = BatchReparametrizationSampler.skip
                    BatchReparametrizationSampler.skip.assign(skip + self._sample_size)
                else:
                    skip = tf.constant(0)
                normal_samples = qmc_normal_samples(
                    self._sample_size * mean.shape[-1], batch_size, skip, dtype=cov.dtype
                )  # [S*L, B]
                normal_samples = tf.reshape(
                    normal_samples, (mean.shape[-1], self._sample_size, batch_size)
                )  # [L, S, B]
                normal_samples = tf.transpose(normal_samples, perm=[0, 2, 1])  # [L, B, S]
            else:
                normal_samples = tf.random.normal(
                    [tf.shape(mean)[-1], batch_size, self._sample_size], dtype=cov.dtype
                )  # [L, B, S]
            return normal_samples

        if self._eps is None:
            # dynamically shaped as the same sampler may be called with different sized batches
            self._eps = tf.Variable(sample_eps(), shape=[None, None, self._sample_size])
        tf.cond(
            self._initialized,
            lambda: self._eps,
            lambda: self._eps.assign(sample_eps()),
        )

        if self._initialized:
            tf.debugging.assert_equal(
                batch_size,
                tf.shape(self._eps)[-2],
                f"{type(self).__name__} requires a fixed batch size. Got batch size {batch_size}"
                f" but previous batch size was {tf.shape(self._eps)[-2]}.",
            )

        identity = tf.eye(batch_size, dtype=cov.dtype)  # [B, B]
        cov_cholesky = tf.linalg.cholesky(cov + jitter * identity)  # [..., L, B, B]

        variance_contribution = cov_cholesky @ self._eps  # [..., L, B, S]

        leading_indices = tf.range(tf.rank(variance_contribution) - 3)
        absolute_trailing_indices = [-1, -2, -3] + tf.rank(variance_contribution)
        new_order = tf.concat([leading_indices, absolute_trailing_indices], axis=0)

        return mean[..., None, :, :] + tf.transpose(variance_contribution, new_order)

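# A minimal usage sketch (illustrative only, not part of the module): `model` is a hypothetical
# model supporting `predict_joint`, and `batch_xs` a hypothetical [B, D] batch of query points.
#
#     sampler = BatchReparametrizationSampler(sample_size=100, model=model)
#     samples = sampler.sample(batch_xs)  # [100, B, L]; later calls must use the same batch size B
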
@runtime_checkable
class FeatureDecompositionInternalDataModel(
    SupportsGetKernel,
    SupportsGetMeanFunction,
    SupportsGetObservationNoise,
    SupportsGetInternalData,
    Protocol,
):
    """
    A probabilistic model that supports get_kernel, get_mean_function, get_observation_noise
    and get_internal_data methods.
    """

@runtime_checkable
class FeatureDecompositionInducingPointModel(
    SupportsGetKernel, SupportsGetMeanFunction, SupportsGetInducingVariables, Protocol
):
    """
    A probabilistic model that supports get_kernel, get_mean_function and
    get_inducing_variables methods.
    """

FeatureDecompositionTrajectorySamplerModel = Union[
    FeatureDecompositionInducingPointModel,
    FeatureDecompositionInternalDataModel,
]

FeatureDecompositionTrajectorySamplerModelType = TypeVar(
    "FeatureDecompositionTrajectorySamplerModelType",
    bound=FeatureDecompositionTrajectorySamplerModel,
    contravariant=True,
)


def _is_multioutput_kernel(kernel: Kernel) -> TypeGuard[MultioutputKernel]:
    return isinstance(kernel, MultioutputKernel)


def _get_kernel_function(kernel: Kernel) -> Callable[[TensorType, TensorType], tf.Tensor]:
    # Select between a multioutput kernel and a single-output kernel.
    def K(X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:
        if _is_multioutput_kernel(kernel):
            return kernel(X, X2, full_cov=True, full_output_cov=False)  # [L, M, M]
        else:
            return tf.expand_dims(kernel(X, X2), axis=0)  # [1, M, M]

    return K

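# A minimal sketch of what `_get_kernel_function` returns (illustrative only, not part of the
# module; assumes `import gpflow` in addition to the imports above): a single-output kernel is
# lifted to the multioutput convention by prepending a leading output dimension.
#
#     K = _get_kernel_function(gpflow.kernels.RBF())
#     K(tf.random.uniform([5, 2]))  # [1, 5, 5], versus the [5, 5] returned by the raw kernel
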
class FeatureDecompositionTrajectorySampler(
    TrajectorySampler[FeatureDecompositionTrajectorySamplerModelType],
    ABC,
):
    r"""
    This is a general class to build functions that approximate a trajectory sampled from an
    underlying Gaussian process model. In particular, we approximate the Gaussian processes'
    posterior samples as the finite feature approximation

    .. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i

    where :math:`\phi_i` are m features and :math:`\theta_i` are feature weights sampled from a
    given distribution.

    Achieving consistency (ensuring the same sample draw for all evaluations of a particular
    trajectory function) for exact sample draws from a GP is prohibitively costly because it
    scales cubically with the number of query points. However, finite feature representations can
    be evaluated with constant cost regardless of the required number of queries.
    """

    def __init__(
        self,
        model: FeatureDecompositionTrajectorySamplerModelType,
        feature_functions: ResampleableRandomFourierFeatureFunctions,
    ):
        """
        :param model: The model to sample from.
        :raise ValueError: If ``dataset`` is empty.
        """
        super().__init__(model)

        self._feature_functions = feature_functions
        self._weight_sampler: Optional[Callable[[int], TensorType]] = None  # lazy init
        self._mean_function = model.get_mean_function()

    def __repr__(self) -> str:
        """"""
        return f"""{self.__class__.__name__}(
        {self._model!r},
        {self._feature_functions!r})
        """

    def get_trajectory(self) -> TrajectoryFunction:
        """
        Generate an approximate function draw (trajectory) by sampling weights and evaluating the
        feature functions.

        :return: A trajectory function representing an approximate trajectory from the Gaussian
            process, taking an input of shape `[N, B, D]` and returning shape `[N, B, L]`, where
            `L` is the number of outputs of the model.
        """
        weight_sampler = self._prepare_weight_sampler()  # prep feature weight distribution

        return feature_decomposition_trajectory(
            feature_functions=self._feature_functions,
            weight_sampler=weight_sampler,
            mean_function=self._mean_function,
            encoder=get_encoder(self._model),
        )

    def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunction:
        """
        Efficiently update a :const:`TrajectoryFunction` to reflect an update in its underlying
        :class:`ProbabilisticModel` and resample accordingly.

        For a :class:`FeatureDecompositionTrajectorySampler`, updating the sampler corresponds to
        resampling the feature functions (taking into account any changed kernel parameters) and
        recalculating the weight distribution.

        :param trajectory: The trajectory function to be resampled.
        :return: The new resampled trajectory function.
        """
        tf.debugging.Assert(
            isinstance(trajectory, feature_decomposition_trajectory), [tf.constant([])]
        )

        self._feature_functions.resample()  # resample Fourier feature decomposition
        weight_sampler = self._prepare_weight_sampler()  # recalculate weight distribution

        cast(feature_decomposition_trajectory, trajectory).update(weight_sampler=weight_sampler)

        return trajectory  # return trajectory with updated features and weight distribution

    def resample_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunction:
        """
        Efficiently resample a :const:`TrajectoryFunction` in-place to avoid function retracing
        with every new sample.

        :param trajectory: The trajectory function to be resampled.
        :return: The new resampled trajectory function.
        """
        tf.debugging.Assert(
            isinstance(trajectory, feature_decomposition_trajectory), [tf.constant([])]
        )
        cast(feature_decomposition_trajectory, trajectory).resample()
        return trajectory  # return trajectory with resampled weights

    @abstractmethod
    def _prepare_weight_sampler(self) -> Callable[[int], TensorType]:  # [B] -> [B, F, L]
        """
        Calculate the posterior of the feature weights for the specified feature functions,
        returning a function that takes in a batch size `B` and returns `B` samples for the
        weights of each of the `F` features for `L` outputs.
        """
        raise NotImplementedError

class RandomFourierFeatureTrajectorySampler(
    FeatureDecompositionTrajectorySampler[FeatureDecompositionInternalDataModel]
):
    r"""
    This class builds functions that approximate a trajectory sampled from an underlying Gaussian
    process model. For tractability, the Gaussian process is approximated with a Bayesian linear
    model across a set of features sampled from the Fourier feature decomposition of the model's
    kernel. See :cite:`hernandez2014predictive` for details. Currently we do not support models
    with multiple latent Gaussian processes.

    In particular, we approximate the Gaussian processes' posterior samples as the finite feature
    approximation

    .. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i

    where :math:`\phi_i` are m Fourier features and :math:`\theta_i` are feature weights sampled
    from a posterior distribution that depends on the feature values at the model's datapoints.

    Our implementation follows :cite:`hernandez2014predictive`, with our calculations differing
    slightly depending on properties of the problem. In particular, we use different calculation
    strategies depending on the number of considered features m and the number of data points n.

    If :math:`m<n` then we follow Appendix A of :cite:`hernandez2014predictive` and calculate the
    posterior distribution for :math:`\theta` following their Bayesian linear regression
    motivation, i.e. the computation revolves around an O(m^3) inversion of a design matrix.

    If :math:`n<m` then we use the kernel trick to recast computation to revolve around an O(n^3)
    inversion of a gram matrix. As well as being more efficient in early BO steps (where
    :math:`n<m`), this second computation method allows much larger choices of m (as required to
    approximate very flexible kernels).
    """

    def __init__(
        self,
        model: FeatureDecompositionInternalDataModel,
        num_features: int = 1000,
    ):
        """
        :param model: The model to sample from.
        :param num_features: The number of features used to approximate the kernel. We use a
            default of 1000 as it typically performs well for a wide range of kernels. Note that
            very smooth kernels (e.g. RBF) can be well-approximated with fewer features.
        :raise ValueError: If ``dataset`` is empty.
        """
        if not isinstance(model, FeatureDecompositionInternalDataModel):
            raise NotImplementedError(
                f"RandomFourierFeatureTrajectorySampler only works with models with "
                f"get_kernel, get_observation_noise and get_internal_data methods; "
                f"but received {model!r}."
            )

        tf.debugging.assert_positive(num_features)
        self._num_features = num_features

        feature_functions = ResampleableRandomFourierFeatureFunctions(model, self._num_features)
        super().__init__(model, feature_functions)

    def _prepare_weight_sampler(self) -> Callable[[int], TensorType]:  # [B] -> [B, F, 1]
        """
        Calculate the posterior of theta (the feature weights) for the RFFs, returning a function
        that takes in a batch size `B` and returns `B` samples for the weights of each of the RFF
        `F` features for one output.
        """
        dataset = self._model.get_internal_data()
        num_data = tf.shape(dataset.query_points)[0]  # n
        if (
            self._num_features < num_data
        ):  # if m < n then calculate posterior in design space (an m*m matrix inversion)
            theta_posterior = self._prepare_theta_posterior_in_design_space()
        else:  # if n <= m then calculate posterior in gram space (an n*n matrix inversion)
            theta_posterior = self._prepare_theta_posterior_in_gram_space()

        return lambda b: tf.expand_dims(theta_posterior.sample(b), axis=-1)

    def _prepare_theta_posterior_in_design_space(self) -> tfp.distributions.MultivariateNormalTriL:
        r"""
        Calculate the posterior of theta (the feature weights) in the design space. This
        distribution is a Gaussian

        .. math:: \theta \sim N(D^{-1}\Phi^Ty,D^{-1}\sigma^2)

        where the [m,m] design matrix :math:`D=(\Phi^T\Phi + \sigma^2I_m)` is defined for the
        [n,m] matrix of feature evaluations across the training data :math:`\Phi` and observation
        noise variance :math:`\sigma^2`.
        """
        dataset = self._model.get_internal_data()
        phi = self._feature_functions(tf.convert_to_tensor(dataset.query_points))  # [n, m]
        D = tf.matmul(phi, phi, transpose_a=True)  # [m, m]
        s = self._model.get_observation_noise() * tf.eye(self._num_features, dtype=phi.dtype)
        L = tf.linalg.cholesky(D + s)
        D_inv = tf.linalg.cholesky_solve(L, tf.eye(self._num_features, dtype=phi.dtype))

        residuals = dataset.observations - self._model.get_mean_function()(dataset.query_points)
        theta_posterior_mean = tf.matmul(D_inv, tf.matmul(phi, residuals, transpose_a=True))[
            :, 0
        ]  # [m,]
        theta_posterior_chol_covariance = tf.linalg.cholesky(
            D_inv * self._model.get_observation_noise()
        )  # [m, m]

        return tfp.distributions.MultivariateNormalTriL(
            theta_posterior_mean, theta_posterior_chol_covariance
        )

    def _prepare_theta_posterior_in_gram_space(self) -> tfp.distributions.MultivariateNormalTriL:
        r"""
        Calculate the posterior of theta (the feature weights) in the gram space.

        .. math:: \theta \sim N(\Phi^TG^{-1}y,I_m - \Phi^TG^{-1}\Phi)

        where the [n,n] gram matrix :math:`G=(\Phi\Phi^T + \sigma^2I_n)` is defined for the [n,m]
        matrix of feature evaluations across the training data :math:`\Phi` and observation noise
        variance :math:`\sigma^2`.
        """
        dataset = self._model.get_internal_data()
        num_data = tf.shape(dataset.query_points)[0]  # n

        phi = self._feature_functions(tf.convert_to_tensor(dataset.query_points))  # [n, m]
        G = tf.matmul(phi, phi, transpose_b=True)  # [n, n]
        s = self._model.get_observation_noise() * tf.eye(num_data, dtype=phi.dtype)
        L = tf.linalg.cholesky(G + s)
        L_inv_phi = tf.linalg.triangular_solve(L, phi)  # [n, m]
        residuals = dataset.observations - self._model.get_mean_function()(
            dataset.query_points
        )  # [n, 1]
        L_inv_y = tf.linalg.triangular_solve(L, residuals)  # [n, 1]

        theta_posterior_mean = tf.tensordot(tf.transpose(L_inv_phi), L_inv_y, [[-1], [-2]])[
            :, 0
        ]  # [m,]
        theta_posterior_covariance = tf.eye(self._num_features, dtype=phi.dtype) - tf.tensordot(
            tf.transpose(L_inv_phi), L_inv_phi, [[-1], [-2]]
        )  # [m, m]
        theta_posterior_chol_covariance = tf.linalg.cholesky(theta_posterior_covariance)  # [m, m]

        return tfp.distributions.MultivariateNormalTriL(
            theta_posterior_mean, theta_posterior_chol_covariance
        )

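# A minimal usage sketch (illustrative only, not part of the module): `gpr` is a hypothetical
# model wrapped as trieste.models.gpflow.GaussianProcessRegression, which exposes the get_kernel,
# get_mean_function, get_observation_noise and get_internal_data methods required above, and `xs`
# is a hypothetical [N, D] tensor of query points.
#
#     sampler = RandomFourierFeatureTrajectorySampler(gpr, num_features=500)
#     trajectory = sampler.get_trajectory()
#     values = trajectory(xs[:, None, :])  # [N, 1, L] for inputs of shape [N, 1, D]
#     trajectory = sampler.update_trajectory(trajectory)  # refresh after the model is retrained
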
class DecoupledTrajectorySampler(
    FeatureDecompositionTrajectorySampler[
        Union[
            FeatureDecompositionInducingPointModel,
            FeatureDecompositionInternalDataModel,
        ]
    ]
):
    r"""
    This class builds functions that approximate a trajectory sampled from an underlying Gaussian
    process model using decoupled sampling. See :cite:`wilson2020efficiently` for an introduction
    to decoupled sampling.

    Unlike our :class:`RandomFourierFeatureTrajectorySampler`, which uses an RFF decomposition to
    approximate the Gaussian process posterior, a :class:`DecoupledTrajectorySampler` only uses an
    RFF decomposition to approximate the Gaussian process prior, and instead uses a canonical
    decomposition to discretize the effect of updating the prior on the given data.

    In particular, we approximate the Gaussian processes' posterior samples as the finite feature
    approximation

    .. math:: \hat{f}(.) = \sum_{i=1}^L w_i\phi_i(.) + \sum_{j=1}^m v_jk(.,z_j)

    where :math:`\phi_i(.)` and :math:`w_i` are the Fourier features and their weights that
    discretize the prior. In contrast, :math:`k(.,z_j)` and :math:`v_j` are the canonical features
    and their weights that discretize the data update.

    The expression for :math:`v_j` depends on whether we are using an exact Gaussian process or a
    sparse approximation. See eq. (13) in :cite:`wilson2020efficiently` for details.

    Note that if a model is both of :class:`FeatureDecompositionInducingPointModel` type and
    :class:`FeatureDecompositionInternalDataModel` type,
    :class:`FeatureDecompositionInducingPointModel` will take priority and inducing points will be
    used for computations rather than data.
    """

    def __init__(
        self,
        model: Union[
            FeatureDecompositionInducingPointModel,
            FeatureDecompositionInternalDataModel,
        ],
        num_features: int = 1000,
    ):
        """
        :param model: The model to sample from.
        :param num_features: The number of features used to approximate the kernel. We use a
            default of 1000 as it typically performs well for a wide range of kernels. Note that
            very smooth kernels (e.g. RBF) can be well-approximated with fewer features.
        :raise NotImplementedError: If the model is not of valid type.
        """
        if not isinstance(
            model, (FeatureDecompositionInducingPointModel, FeatureDecompositionInternalDataModel)
        ):
            raise NotImplementedError(
                f"DecoupledTrajectorySampler only works with models that either support "
                f"get_kernel, get_observation_noise and get_internal_data or support get_kernel "
                f"and get_inducing_variables; but received {model!r}."
            )
        tf.debugging.assert_positive(num_features)
        self._num_features = num_features
        feature_functions = ResampleableDecoupledFeatureFunctions(model, self._num_features)

        super().__init__(model, feature_functions)

    def _prepare_weight_sampler(self) -> Callable[[int], TensorType]:  # [B] -> [B, F + M, L]
        """
        Prepare the sampler function that provides samples of the feature weights for both the
        RFF and canonical feature functions, i.e. we return a function that takes in a batch size
        `B` and returns `B` samples for the weights of each of the `F` RFF features and `M`
        canonical features for `L` outputs.
        """
        kernel_K = _get_kernel_function(self._model.get_kernel())

        if isinstance(self._model, FeatureDecompositionInducingPointModel):
            (  # extract variational parameters
                inducing_points,
                q_mu,
                q_sqrt,
                whiten,
            ) = self._model.get_inducing_variables()  # [M, D], [M, L], [L, M, M], []

            Kmm = kernel_K(inducing_points, inducing_points)  # [L, M, M]
            Kmm += tf.eye(tf.shape(inducing_points)[0], dtype=Kmm.dtype) * DEFAULTS.JITTER
        else:
            # massage quantities from GP to look like variational parameters
            internal_data = self._model.get_internal_data()
            inducing_points = internal_data.query_points  # [M, D]
            q_mu = self._model.get_internal_data().observations  # [M, L]
            q_mu = q_mu - self._model.get_mean_function()(
                inducing_points
            )  # account for mean function
            q_sqrt = tf.eye(tf.shape(inducing_points)[0], dtype=tf.float64)  # [M, M]
            q_sqrt = tf.expand_dims(q_sqrt, axis=0)  # [1, M, M]
            q_sqrt = tf.math.sqrt(self._model.get_observation_noise()) * q_sqrt
            whiten = False
            Kmm = kernel_K(inducing_points, inducing_points) + q_sqrt**2  # [L, M, M]

        M, L = tf.shape(q_mu)

        tf.debugging.assert_shapes(
            [
                (inducing_points, ["M", "D"]),
                (q_mu, ["M", "L"]),
                (q_sqrt, ["L", "M", "M"]),
                (Kmm, ["L", "M", "M"]),
            ]
        )

        def weight_sampler(batch_size: int) -> Tuple[TensorType, TensorType]:
            prior_weights = tf.random.normal(  # Non-RFF features will require scaling here
                [L, self._num_features, batch_size], dtype=tf.float64
            )  # [L, F, B]

            u_noise_sample = tf.matmul(
                q_sqrt,  # [L, M, M]
                tf.random.normal((L, M, batch_size), dtype=tf.float64),  # [L, M, B]
            )  # [L, M, B]
            u_sample = tf.linalg.matrix_transpose(q_mu)[..., None] + u_noise_sample  # [L, M, B]

            if whiten:
                Luu = tf.linalg.cholesky(Kmm)  # [L, M, M]
                u_sample = tf.matmul(Luu, u_sample)  # [L, M, B]

            # It is important that the feature-function is called with a tensor, instead of a
            # parameter (which inducing points can be). This is to ensure pickling works
            # correctly. The first time a Keras layer (i.e. the feature-functions) is built, the
            # shape of the input is used to set the input-spec. If the input is a parameter, the
            # input-spec will not be for an ordinary tensor and pickling will fail.
            phi_Z = self._feature_functions(tf.convert_to_tensor(inducing_points))[
                ..., : self._num_features
            ]  # [M, F] or [L, M, F]
            weight_space_prior_Z = phi_Z @ prior_weights  # [L, M, B]

            diff = u_sample - weight_space_prior_Z  # [L, M, B]
            v = compute_A_inv_b(Kmm, diff)  # [L, M, B]

            tf.debugging.assert_shapes([(v, ["L", "M", "B"]), (prior_weights, ["L", "F", "B"])])

            return tf.transpose(
                tf.concat([prior_weights, v], axis=1), perm=[2, 1, 0]
            )  # [B, F + M, L]

        return weight_sampler

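# A minimal usage sketch (illustrative only, not part of the module): `svgp` is a hypothetical
# sparse model wrapped as trieste.models.gpflow.SparseVariational, which exposes the get_kernel,
# get_mean_function and get_inducing_variables methods required above, and `xs` is a hypothetical
# [N, D] tensor of query points.
#
#     sampler = DecoupledTrajectorySampler(svgp, num_features=1000)
#     trajectory = sampler.get_trajectory()
#     values = trajectory(xs[:, None, :])  # [N, 1, L]; the same weight draw is reused every call
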
class ResampleableRandomFourierFeatureFunctions(RandomFourierFeaturesCosine):
    """
    A wrapper around GPflux's random Fourier feature function that allows for efficient in-place
    updating when generating new decompositions. In particular, the bias and weights are stored as
    variables, which can then be updated by calling :meth:`resample` without triggering expensive
    graph retracing.

    Note that if a model is both of :class:`FeatureDecompositionInducingPointModel` type and
    :class:`FeatureDecompositionInternalDataModel` type,
    :class:`FeatureDecompositionInducingPointModel` will take priority and inducing points will be
    used for computations rather than data.
    """

    def __init__(
        self,
        model: Union[
            FeatureDecompositionInducingPointModel,
            FeatureDecompositionInternalDataModel,
        ],
        n_components: int,
    ):
        """
        :param model: The model that will be approximated by these feature functions.
        :param n_components: The desired number of features.
        :raise NotImplementedError: If the model is not of valid type.
        """
        if not isinstance(
            model,
            (
                FeatureDecompositionInducingPointModel,
                FeatureDecompositionInternalDataModel,
            ),
        ):
            raise NotImplementedError(
                f"ResampleableRandomFourierFeatureFunctions only work with models that either "
                f"support get_kernel, get_observation_noise and get_internal_data or support "
                f"get_kernel and get_inducing_variables; but received {model!r}."
            )

        super().__init__(model.get_kernel(), n_components, dtype=tf.float64)

        if isinstance(model, SupportsGetInducingVariables):
            dummy_X = model.get_inducing_variables()[0][0:1, :]
        else:
            dummy_X = model.get_internal_data().query_points[0:1, :]

        dummy_X = self.kernel.slice(dummy_X, None)[0]  # keep only the active dims from the kernel

        # Always build the weights and biases. This is important for saving the trajectory (using
        # tf.saved_model.save) before it has been used.
        self.build(dummy_X.shape)

    def resample(self) -> None:
        """
        Resample weights and biases.
        """
        self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype))
        self.W.assign(self._weights_init(tf.shape(self.W), dtype=self._dtype))

    def call(self, inputs: TensorType) -> TensorType:  # [N, D] -> [N, F] or [L, N, F]
        """
        Evaluate the basis functions at ``inputs``.
        """
        inputs = self.kernel.slice(inputs, None)[0]  # keep only active dims from the kernel
        return super().call(inputs)  # [N, F] or [L, N, F]

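# A minimal usage sketch (illustrative only, not part of the module): `model` is a hypothetical
# model satisfying one of the feature-decomposition protocols above, and `xs` a hypothetical
# [N, D] tensor of inputs.
#
#     features = ResampleableRandomFourierFeatureFunctions(model, n_components=100)
#     phi = features(xs)  # [N, 100] feature evaluations for a single-output kernel
#     features.resample()  # draw new Fourier weights and biases in place, without retracing
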
class ResampleableDecoupledFeatureFunctions(ResampleableRandomFourierFeatureFunctions):
    """
    A wrapper around our :class:`ResampleableRandomFourierFeatureFunctions` which, rather than
    evaluating just the `F` RFF functions, evaluates their concatenation with evaluations of the
    canonical basis functions.

    Note that if a model is both of :class:`FeatureDecompositionInducingPointModel` type and
    :class:`FeatureDecompositionInternalDataModel` type,
    :class:`FeatureDecompositionInducingPointModel` will take priority and inducing points will be
    used for computations rather than data.
    """

    def __init__(
        self,
        model: Union[
            FeatureDecompositionInducingPointModel,
            FeatureDecompositionInternalDataModel,
        ],
        n_components: int,
    ):
        """
        :param model: The model that will be approximated by these feature functions.
        :param n_components: The desired number of features.
        """
        super().__init__(model, n_components)

        if isinstance(model, SupportsGetInducingVariables):
            self._inducing_points = model.get_inducing_variables()[0]  # [M, D]
        else:
            self._inducing_points = model.get_internal_data().query_points  # [M, D]

        kernel_K = _get_kernel_function(self.kernel)
        self._canonical_feature_functions = lambda x: tf.linalg.matrix_transpose(
            kernel_K(self._inducing_points, x)
        )

    def call(self, inputs: TensorType) -> TensorType:  # [N, D] -> [N, F + M] or [L, N, F + M]
        """
        Combine prior basis functions with canonical basis functions.
        """
        fourier_feature_eval = super().call(inputs)  # [N, F] or [L, N, F]
        canonical_feature_eval = self._canonical_feature_functions(inputs)  # [1, N, M] or [L, N, M]
        # ensure matching rank between features, i.e. drop the leading 1 dimension
        matched_shape = tf.shape(canonical_feature_eval)[-tf.rank(fourier_feature_eval) :]
        canonical_feature_eval = tf.reshape(canonical_feature_eval, matched_shape)
        return tf.concat([fourier_feature_eval, canonical_feature_eval], axis=-1)

class feature_decomposition_trajectory(TrajectoryFunctionClass):
    r"""
    An approximate sample from a Gaussian process posterior, represented as a finite weighted sum
    of features. A trajectory is given by

    .. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i

    where :math:`\phi_i` are m feature functions and :math:`\theta_i` are feature weights sampled
    from a posterior distribution.

    The number of trajectories (i.e. batch size) is determined from the first call of the
    trajectory. In order to change the batch size, a new :class:`TrajectoryFunction` must be
    built.
    """

    def __init__(
        self,
        feature_functions: Callable[[TensorType], TensorType],
        weight_sampler: Callable[[int], TensorType],
        mean_function: Callable[[TensorType], TensorType],
        encoder: EncoderFunction | None = None,
    ):
        """
        :param feature_functions: Set of feature functions.
        :param weight_sampler: New sampler that generates feature weight samples.
        :param mean_function: The underlying model's mean function.
        :param encoder: Optional encoder with which to transform input points.
        """
        self._feature_functions = feature_functions
        self._mean_function = mean_function
        self._weight_sampler = weight_sampler
        self._encoder = encoder
        self._initialized = tf.Variable(False)

        self._weights_sample = tf.Variable(  # dummy init to be updated before trajectory evaluation
            tf.ones([0, 0, 0], dtype=tf.float64), shape=[None, None, None]
        )

        self._batch_size = tf.Variable(
            0, dtype=tf.int32
        )  # dummy init to be updated before trajectory evaluation

    @tf.function
    def __call__(self, inputs: TensorType) -> TensorType:  # [N, B, D] -> [N, B, L]
        """Call trajectory function."""
        if self._encoder is not None:
            inputs = self._encoder(inputs)

        if not self._initialized:  # work out desired batch size from input
            self._batch_size.assign(tf.shape(inputs)[-2])  # B
            self.resample()  # sample B feature weights
            self._initialized.assign(True)

        tf.debugging.assert_equal(
            tf.shape(inputs)[-2],
            self._batch_size.value(),
            message=f"""
            This trajectory only supports batch sizes of {self._batch_size}.
            If you wish to change the batch size you must get a new trajectory
            by calling the get_trajectory method of the trajectory sampler.
            """,
        )

        flat_inputs, unflatten = flatten_leading_dims(inputs)  # [N*B, D]
        flattened_feature_evaluations = self._feature_functions(
            flat_inputs
        )  # [N*B, F + M] or [L, N*B, F + M]
        # ensure tensor is always rank 3
        rank3_shape = tf.concat([[1], tf.shape(flattened_feature_evaluations)], axis=0)[-3:]
        flattened_feature_evaluations = tf.reshape(flattened_feature_evaluations, rank3_shape)
        flattened_feature_evaluations = tf.transpose(
            flattened_feature_evaluations, perm=[1, 2, 0]
        )  # [N*B, F + M, L]
        feature_evaluations = unflatten(flattened_feature_evaluations)  # [N, B, F + M, L]

        mean = self._mean_function(inputs)  # account for the model's mean function

        return tf.reduce_sum(feature_evaluations * self._weights_sample, -2) + mean  # [N, B, L]

    def resample(self) -> None:
        """
        Efficiently resample in-place without retracing.
        """
        self._weights_sample.assign(  # [B, F + M, L]
            self._weight_sampler(self._batch_size)
        )  # resample weights

    def update(self, weight_sampler: Callable[[int], TensorType]) -> None:
        """
        Efficiently update the trajectory with a new weight distribution and resample its weights.

        :param weight_sampler: New sampler that generates feature weight samples.
        """
        self._weight_sampler = weight_sampler  # update weight sampler
        self.resample()  # resample weights
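
# A minimal sketch of the trajectory semantics (illustrative only, not part of the module):
# `trajectory` is a hypothetical instance returned by one of the samplers above, `xs_batched` a
# hypothetical [N, B, D] tensor, and `new_weight_sampler` a hypothetical replacement weight
# sampler of the same signature as used above.
#
#     out = trajectory(xs_batched)  # first call fixes the batch size B from the [N, B, D] input
#     trajectory.resample()  # new weight draw from the same posterior, without retracing
#     trajectory.update(new_weight_sampler)  # switch weight distribution in place and resample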