Source code for trieste.acquisition.function.greedy_batch

# Copyright 2021 The Trieste Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains local penalization-based acquisition function builders.
"""
from __future__ import annotations

from typing import Callable, Dict, Mapping, Optional, Union, cast

import gpflow
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import check_shapes
from typing_extensions import Protocol, runtime_checkable

from ...data import Dataset
from ...models import FastUpdateModel, ModelStack, ProbabilisticModel
from ...models.interfaces import (
    PredictJointModelStack,
    PredictJointPredictYModelStack,
    PredictYModelStack,
    SupportsGetKernel,
    SupportsGetObservationNoise,
    SupportsPredictJoint,
    SupportsPredictY,
)
from ...observer import OBJECTIVE
from ...space import SearchSpace
from ...types import Tag, TensorType
from ..interface import (
    AcquisitionFunction,
    AcquisitionFunctionBuilder,
    GreedyAcquisitionFunctionBuilder,
    PenalizationFunction,
    SingleModelAcquisitionBuilder,
    SingleModelGreedyAcquisitionBuilder,
    UpdatablePenalizationFunction,
)
from .entropy import MinValueEntropySearch
from .function import ExpectedImprovement, MakePositive, expected_improvement


[docs] class LocalPenalization(SingleModelGreedyAcquisitionBuilder[ProbabilisticModel]): r""" Builder of the acquisition function maker for greedily collecting batches by local penalization. The resulting :const:`AcquisitionFunctionMaker` takes in a set of pending points and returns a base acquisition function penalized around those points. An estimate of the objective function's Lipschitz constant is used to control the size of penalization. Local penalization allows us to perform batch Bayesian optimization with a standard (non-batch) acquisition function. All that we require is that the acquisition function takes strictly positive values. By iteratively building a batch of points though sequentially maximizing this acquisition function but down-weighted around locations close to the already chosen (pending) points, local penalization provides diverse batches of candidate points. Local penalization is applied to the acquisition function multiplicatively. However, to improve numerical stability, we perform additive penalization in a log space. The Lipschitz constant and additional penalization parameters are estimated once when first preparing the acquisition function with no pending points. These estimates are reused for all subsequent function calls. """ def __init__( self, search_space: SearchSpace, num_samples: int = 500, penalizer: Optional[ Callable[ [ProbabilisticModel, TensorType, TensorType, TensorType], Union[PenalizationFunction, UpdatablePenalizationFunction], ] ] = None, base_acquisition_function_builder: ( ExpectedImprovement | MinValueEntropySearch[ProbabilisticModel] | MakePositive[ProbabilisticModel] | None ) = None, ): """ :param search_space: The global search space over which the optimisation is defined. :param num_samples: Size of the random sample over which the Lipschitz constant is estimated. We recommend scaling this with search space dimension. :param penalizer: The chosen penalization method (defaults to soft penalization). This should be a function that accepts a model, pending points, lipschitz constant and eta and returns a PenalizationFunction. :param base_acquisition_function_builder: Base acquisition function to be penalized (defaults to expected improvement). Local penalization only supports strictly positive acquisition functions. :raise tf.errors.InvalidArgumentError: If ``num_samples`` is not positive. """ tf.debugging.assert_positive(num_samples) self._search_space = search_space self._num_samples = num_samples self._lipschitz_penalizer = soft_local_penalizer if penalizer is None else penalizer if base_acquisition_function_builder is None: self._base_builder: SingleModelAcquisitionBuilder[ProbabilisticModel] = ( ExpectedImprovement() ) else: self._base_builder = base_acquisition_function_builder self._lipschitz_constant = None self._eta = None self._base_acquisition_function: Optional[AcquisitionFunction] = None self._penalization: Optional[PenalizationFunction | UpdatablePenalizationFunction] = None self._penalized_acquisition: Optional[AcquisitionFunction] = None
[docs] def prepare_acquisition_function( self, model: ProbabilisticModel, dataset: Optional[Dataset] = None, pending_points: Optional[TensorType] = None, ) -> AcquisitionFunction: """ :param model: The model. :param dataset: The data from the observer. Must be populated. :param pending_points: The points we penalize with respect to. :return: The (log) expected improvement penalized with respect to the pending points. :raise tf.errors.InvalidArgumentError: If the ``dataset`` is empty. """ tf.debugging.Assert(dataset is not None, [tf.constant([])]) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") acq = self._update_base_acquisition_function(dataset, model) if pending_points is not None and len(pending_points) != 0: acq = self._update_penalization(acq, dataset, model, pending_points) return acq
[docs] def update_acquisition_function( self, function: AcquisitionFunction, model: ProbabilisticModel, dataset: Optional[Dataset] = None, pending_points: Optional[TensorType] = None, new_optimization_step: bool = True, ) -> AcquisitionFunction: """ :param function: The acquisition function to update. :param model: The model. :param dataset: The data from the observer. Must be populated. :param pending_points: Points already chosen to be in the current batch (of shape [M,D]), where M is the number of pending points and D is the search space dimension. :param new_optimization_step: Indicates whether this call to update_acquisition_function is to start of a new optimization step, of to continue collecting batch of points for the current step. Defaults to ``True``. :return: The updated acquisition function. """ tf.debugging.Assert(dataset is not None, [tf.constant([])]) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") tf.debugging.Assert(self._base_acquisition_function is not None, [tf.constant([])]) if new_optimization_step: self._update_base_acquisition_function(dataset, model) if pending_points is None or len(pending_points) == 0: # no penalization required if no pending_points return cast(AcquisitionFunction, self._base_acquisition_function) return self._update_penalization(function, dataset, model, pending_points)
def _update_penalization( self, function: Optional[AcquisitionFunction], dataset: Dataset, model: ProbabilisticModel, pending_points: Optional[TensorType] = None, ) -> AcquisitionFunction: tf.debugging.assert_rank(pending_points, 2) if self._penalized_acquisition is not None and isinstance( self._penalization, UpdatablePenalizationFunction ): # if possible, just update the penalization function variables self._penalization.update(pending_points, self._lipschitz_constant, self._eta) return self._penalized_acquisition else: # otherwise construct a new penalized acquisition function self._penalization = self._lipschitz_penalizer( model, pending_points, self._lipschitz_constant, self._eta ) self._penalized_acquisition = PenalizedAcquisition( cast(PenalizedAcquisition, self._base_acquisition_function), self._penalization ) return self._penalized_acquisition @tf.function(experimental_relax_shapes=True) def _get_lipschitz_estimate( self, model: ProbabilisticModel, sampled_points: TensorType ) -> tuple[TensorType, TensorType]: with tf.GradientTape() as g: g.watch(sampled_points) mean, _ = model.predict(sampled_points) grads = g.gradient(mean, sampled_points) grads_norm = tf.norm(grads, axis=1) max_grads_norm = tf.reduce_max(grads_norm) eta = tf.reduce_min(mean, axis=0) return max_grads_norm, eta def _update_base_acquisition_function( self, dataset: Dataset, model: ProbabilisticModel ) -> AcquisitionFunction: samples = self._search_space.sample(num_samples=self._num_samples) samples = tf.concat([dataset.query_points, samples], 0) lipschitz_constant, eta = self._get_lipschitz_estimate(model, samples) if lipschitz_constant < 1e-5: # threshold to improve numerical stability for 'flat' models lipschitz_constant = 10 self._lipschitz_constant = lipschitz_constant self._eta = eta if self._base_acquisition_function is not None: self._base_acquisition_function = self._base_builder.update_acquisition_function( self._base_acquisition_function, model, dataset=dataset, ) elif isinstance(self._base_builder, ExpectedImprovement): # reuse eta estimate self._base_acquisition_function = cast( AcquisitionFunction, expected_improvement(model, self._eta) ) else: self._base_acquisition_function = self._base_builder.prepare_acquisition_function( model, dataset=dataset, ) return self._base_acquisition_function
[docs] class PenalizedAcquisition: """Class representing a penalized acquisition function.""" # (note that this needs to be defined as a top level class make it pickleable) def __init__( self, base_acquisition_function: AcquisitionFunction, penalization: PenalizationFunction ): """ :param base_acquisition_function: Base (unpenalized) acquisition function. :param penalization: Penalization function. """ self._base_acquisition_function = base_acquisition_function self._penalization = penalization @tf.function def __call__(self, x: TensorType) -> TensorType: log_acq = tf.math.log(self._base_acquisition_function(x)) + tf.math.log( self._penalization(x) ) return tf.math.exp(log_acq)
[docs] class local_penalizer(UpdatablePenalizationFunction): def __init__( self, model: ProbabilisticModel, pending_points: TensorType, lipschitz_constant: TensorType, eta: TensorType, ): """Initialize the local penalizer. :param model: The model over the specified ``dataset``. :param pending_points: The points we penalize with respect to. :param lipschitz_constant: The estimated Lipschitz constant of the objective function. :param eta: The estimated global minima. :return: The local penalization function. This function will raise :exc:`ValueError` or :exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater than one.""" self._model = model mean_pending, variance_pending = model.predict(pending_points) self._pending_points = tf.Variable(pending_points, shape=[None, *pending_points.shape[1:]]) self._radius = tf.Variable( tf.transpose((mean_pending - eta) / lipschitz_constant), shape=[1, None], ) self._scale = tf.Variable( tf.transpose(tf.sqrt(variance_pending) / lipschitz_constant), shape=[1, None], )
[docs] def update( self, pending_points: TensorType, lipschitz_constant: TensorType, eta: TensorType, ) -> None: """Update the local penalizer with new variable values.""" mean_pending, variance_pending = self._model.predict(pending_points) self._pending_points.assign(pending_points) self._radius.assign(tf.transpose((mean_pending - eta) / lipschitz_constant)) self._scale.assign(tf.transpose(tf.sqrt(variance_pending) / lipschitz_constant))
[docs] class soft_local_penalizer(local_penalizer): r""" Return the soft local penalization function used for single-objective greedy batch Bayesian optimization in :cite:`Gonzalez:2016`. Soft penalization returns the probability that a candidate point does not belong in the exclusion zones of the pending points. For model posterior mean :math:`\mu`, model posterior variance :math:`\sigma^2`, current "best" function value :math:`\eta`, and an estimated Lipschitz constant :math:`L`,the penalization from a set of pending point :math:`x'` on a candidate point :math:`x` is given by .. math:: \phi(x, x') = \frac{1}{2}\textrm{erfc}(-z) where :math:`z = \frac{1}{\sqrt{2\sigma^2(x')}}(L||x'-x|| + \eta - \mu(x'))`. The penalization from a set of pending points is just product of the individual penalizations. See :cite:`Gonzalez:2016` for a full derivation. :param model: The model over the specified ``dataset``. :param pending_points: The points we penalize with respect to. :param lipschitz_constant: The estimated Lipschitz constant of the objective function. :param eta: The estimated global minima. :return: The local penalization function. This function will raise :exc:`ValueError` or :exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater than one. """ @tf.function
[docs] def __call__(self, x: TensorType) -> TensorType: tf.debugging.assert_shapes( [(x, [..., 1, None])], message="This penalization function cannot be calculated for batches of points.", ) pairwise_distances = tf.norm( tf.expand_dims(x, 1) - tf.expand_dims(self._pending_points, 0), axis=-1 ) standardised_distances = (pairwise_distances - self._radius) / self._scale normal = tfp.distributions.Normal(tf.constant(0, x.dtype), tf.constant(1, x.dtype)) penalization = normal.cdf(standardised_distances) return tf.reduce_prod(penalization, axis=-1)
[docs] class hard_local_penalizer(local_penalizer): r""" Return the hard local penalization function used for single-objective greedy batch Bayesian optimization in :cite:`Alvi:2019`. Hard penalization is a stronger penalizer than soft penalization and is sometimes more effective See :cite:`Alvi:2019` for details. Our implementation follows theirs, with the penalization from a set of pending points being the product of the individual penalizations. :param model: The model over the specified ``dataset``. :param pending_points: The points we penalize with respect to. :param lipschitz_constant: The estimated Lipschitz constant of the objective function. :param eta: The estimated global minima. :return: The local penalization function. This function will raise :exc:`ValueError` or :exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater than one. """ @tf.function
[docs] def __call__(self, x: TensorType) -> TensorType: tf.debugging.assert_shapes( [(x, [..., 1, None])], message="This penalization function cannot be calculated for batches of points.", ) pairwise_distances = tf.norm( tf.expand_dims(x, 1) - tf.expand_dims(self._pending_points, 0), axis=-1 ) p = -5 # following experiments of :cite:`Alvi:2019`. penalization = ((pairwise_distances / (self._radius + self._scale)) ** p + 1) ** (1 / p) return tf.reduce_prod(penalization, axis=-1)
@runtime_checkable
[docs] class FantasizerModelType( FastUpdateModel, SupportsPredictJoint, SupportsPredictY, SupportsGetKernel, SupportsGetObservationNoise, Protocol, ): """The model requirements for the Fantasizer acquisition function."""
[docs] class FantasizerModelStack( PredictJointModelStack, PredictYModelStack, ModelStack[FantasizerModelType] ): """ A stack of models :class:`FantasizerModelType` models. Note that this delegates predict_joint and predict_y but none of the other methods. """
FantasizerModelOrStack = Union[FantasizerModelType, FantasizerModelStack]
[docs] class Fantasizer(GreedyAcquisitionFunctionBuilder[FantasizerModelOrStack]): r""" Builder of the acquisition function maker for greedily collecting batches. Fantasizer allows us to perform batch Bayesian optimization with any standard (non-batch) acquisition function. Here, every time a query point is chosen by maximising an acquisition function, its corresponding observation is "fantasized", and the models are conditioned further on this new artificial data. This implies that the models need to predict what their updated predictions would be given new data, see :class:`~FastUpdateModel`. These equations are for instance in closed form for the GPR model, see :cite:`chevalier2014corrected` (eqs. 8-10) for details. There are several ways to "fantasize" data: the "kriging believer" heuristic (KB, see :cite:`ginsbourger2010kriging`) uses the mean of the model as observations. "sample" uses samples from the model. """ def __init__( self, base_acquisition_function_builder: Optional[ AcquisitionFunctionBuilder[SupportsPredictJoint] | SingleModelAcquisitionBuilder[SupportsPredictJoint] ] = None, fantasize_method: str = "KB", ): """ :param base_acquisition_function_builder: The acquisition function builder to use. Defaults to :class:`~trieste.acquisition.ExpectedImprovement`. :param fantasize_method: The following options are available: "KB" and "sample". See class docs for more details. :raise tf.errors.InvalidArgumentError: If ``fantasize_method`` is not "KB" or "sample". """ tf.debugging.Assert(fantasize_method in ["KB", "sample"], [tf.constant([])]) if base_acquisition_function_builder is None: base_acquisition_function_builder = ExpectedImprovement() if isinstance(base_acquisition_function_builder, SingleModelAcquisitionBuilder): base_acquisition_function_builder = base_acquisition_function_builder.using(OBJECTIVE) self._builder = base_acquisition_function_builder self._fantasize_method = fantasize_method self._base_acquisition_function: Optional[AcquisitionFunction] = None self._fantasized_acquisition: Optional[AcquisitionFunction] = None self._fantasized_models: Mapping[ Tag, _fantasized_model | ModelStack[SupportsPredictJoint] ] = {} def _update_base_acquisition_function( self, models: Mapping[Tag, FantasizerModelOrStack], datasets: Optional[Mapping[Tag, Dataset]], ) -> AcquisitionFunction: if self._base_acquisition_function is not None: self._base_acquisition_function = self._builder.update_acquisition_function( self._base_acquisition_function, models, datasets ) else: self._base_acquisition_function = self._builder.prepare_acquisition_function( models, datasets ) return self._base_acquisition_function def _update_fantasized_acquisition_function( self, models: Mapping[Tag, FantasizerModelOrStack], datasets: Optional[Mapping[Tag, Dataset]], pending_points: TensorType, ) -> AcquisitionFunction: tf.debugging.assert_rank(pending_points, 2) fantasized_data = { tag: _generate_fantasized_data( fantasize_method=self._fantasize_method, model=model, pending_points=pending_points, ) for tag, model in models.items() } if datasets is None: datasets = fantasized_data else: datasets = {tag: data + fantasized_data[tag] for tag, data in datasets.items()} if self._fantasized_acquisition is None: self._fantasized_models = { tag: _generate_fantasized_model(model, fantasized_data[tag]) for tag, model in models.items() } self._fantasized_acquisition = self._builder.prepare_acquisition_function( cast(Dict[Tag, SupportsPredictJoint], self._fantasized_models), datasets ) else: for tag, model in self._fantasized_models.items(): if isinstance(model, ModelStack): observations = tf.split( fantasized_data[tag].observations, model._event_sizes, axis=-1 ) for submodel, obs in zip(model._models, observations): submodel.update_fantasized_data( Dataset(fantasized_data[tag].query_points, obs) ) else: model.update_fantasized_data(fantasized_data[tag]) self._builder.update_acquisition_function( self._fantasized_acquisition, cast(Dict[Tag, SupportsPredictJoint], self._fantasized_models), datasets, ) return self._fantasized_acquisition
[docs] def prepare_acquisition_function( self, models: Mapping[Tag, FantasizerModelOrStack], datasets: Optional[Mapping[Tag, Dataset]] = None, pending_points: Optional[TensorType] = None, ) -> AcquisitionFunction: """ :param models: The models over each tag. :param datasets: The data from the observer (optional). :param pending_points: Points already chosen to be in the current batch (of shape [M,D]), where M is the number of pending points and D is the search space dimension. :return: An acquisition function. """ for model in models.values(): if not ( isinstance(model, FantasizerModelType) or isinstance(model, ModelStack) and all(isinstance(m, FantasizerModelType) for m in model._models) ): raise NotImplementedError( f"Fantasizer only works with FastUpdateModel models that also support " f"predict_joint, get_kernel and get_observation_noise, or with " f"ModelStack stacks of such models; received {model!r}" ) if pending_points is None: return self._update_base_acquisition_function(models, datasets) else: return self._update_fantasized_acquisition_function(models, datasets, pending_points)
[docs] def update_acquisition_function( self, function: AcquisitionFunction, models: Mapping[Tag, FantasizerModelOrStack], datasets: Optional[Mapping[Tag, Dataset]] = None, pending_points: Optional[TensorType] = None, new_optimization_step: bool = True, ) -> AcquisitionFunction: """ :param function: The acquisition function to update. :param models: The models over each tag. :param datasets: The data from the observer (optional). :param pending_points: Points already chosen to be in the current batch (of shape [M,D]), where M is the number of pending points and D is the search space dimension. :param new_optimization_step: Indicates whether this call to update_acquisition_function is to start of a new optimization step, of to continue collecting batch of points for the current step. Defaults to ``True``. :return: The updated acquisition function. """ if pending_points is None: return self._update_base_acquisition_function(models, datasets) else: return self._update_fantasized_acquisition_function(models, datasets, pending_points)
[docs] def _generate_fantasized_data( fantasize_method: str, model: FantasizerModelOrStack, pending_points: TensorType ) -> Dataset: """ Generates "fantasized" data at pending_points depending on the chosen heuristic: - KB (kriging believer) uses the mean prediction of the models - sample uses samples from the GP posterior. :param fantasize_method: the following options are available: "KB" and "sample". :param model: a model with predict method :param dataset: past data :param pending_points: points at which to fantasize data :return: a fantasized dataset """ if fantasize_method == "KB": fantasized_obs, _ = model.predict(pending_points) elif fantasize_method == "sample": fantasized_obs = model.sample(pending_points, num_samples=1)[0] else: raise NotImplementedError(f"fantasize_method must be KB or sample, received {model!r}") return Dataset(pending_points, fantasized_obs)
def _generate_fantasized_model( model: FantasizerModelOrStack, fantasized_data: Dataset ) -> _fantasized_model | PredictJointPredictYModelStack: if isinstance(model, ModelStack): observations = tf.split(fantasized_data.observations, model._event_sizes, axis=-1) fmods = [] for mod, obs, event_size in zip(model._models, observations, model._event_sizes): fmods.append( ( _fantasized_model(mod, Dataset(fantasized_data.query_points, obs)), event_size, ) ) return PredictJointPredictYModelStack(*fmods) else: return _fantasized_model(model, fantasized_data)
[docs] class _fantasized_model( SupportsPredictJoint, SupportsGetKernel, SupportsGetObservationNoise, SupportsPredictY ): """ Creates a new model from an existing one and additional data. This new model posterior is conditioned on both current model data and the additional one. """ def __init__(self, model: FantasizerModelType, fantasized_data: Dataset): """ :param model: a model, must be of class `FastUpdateModel` :param fantasized_data: additional dataset to condition on :raise NotImplementedError: If model is not of class `FastUpdateModel`. """ self._model = model self._fantasized_query_points = tf.Variable( fantasized_data.query_points, trainable=False, shape=[None, *fantasized_data.query_points.shape[1:]], ) self._fantasized_observations = tf.Variable( fantasized_data.observations, trainable=False, shape=[None, *fantasized_data.observations.shape[1:]], )
[docs] def update_fantasized_data(self, fantasized_data: Dataset) -> None: """ :param fantasized_data: new additional dataset to condition on """ self._fantasized_query_points.assign(fantasized_data.query_points) self._fantasized_observations.assign(fantasized_data.observations)
@check_shapes( "query_points: [batch..., N, D]", "return[0]: [batch..., ..., N, L]", "return[1]: [batch..., ..., N, L]", )
[docs] def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """ This function wraps conditional_predict_f. It cannot directly call conditional_predict_f, since it does not accept query_points with rank > 2. We use map_fn to allow leading dimensions for query_points. :param query_points: shape [...*, N, d] :return: mean, shape [...*, ..., N, L] and cov, shape [...*, ..., N, L], where ... are the leading dimensions of fantasized_data """ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (tf.map_fn) fantasized_data = Dataset( self._fantasized_query_points.value(), self._fantasized_observations.value() ) return self._model.conditional_predict_f(qp, fantasized_data) return _broadcast_predict(query_points, fun)
@check_shapes( "query_points: [batch..., N, D]", "return[0]: [batch..., ..., N, L]", "return[1]: [batch..., ..., L, N, N]", )
[docs] def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """ This function wraps conditional_predict_joint. It cannot directly call conditional_predict_joint, since it does not accept query_points with rank > 2. We use map_fn to allow leading dimensions for query_points. :param query_points: shape [...*, N, D] :return: mean, shape [...*, ..., N, L] and cov, shape [...*, ..., L, N, N], where ... are the leading dimensions of fantasized_data """ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (tf.map_fn) fantasized_data = Dataset( self._fantasized_query_points.value(), self._fantasized_observations.value() ) return self._model.conditional_predict_joint(qp, fantasized_data) return _broadcast_predict(query_points, fun)
@check_shapes( "query_points: [batch..., N, D]", "return: [batch..., ..., S, N, L]", )
[docs] def sample(self, query_points: TensorType, num_samples: int) -> TensorType: """ This function wraps conditional_predict_f_sample. It cannot directly call conditional_predict_joint, since it does not accept query_points with rank > 2. We use map_fn to allow leading dimensions for query_points. :param query_points: shape [...*, N, D] :param num_samples: number of samples. :return: samples of shape [...*, ..., S, N, L], where ... are the leading dimensions of fantasized_data """ leading_dim, query_points_flatten = _get_leading_dim_and_flatten(query_points) # query_points_flatten: [B, n, d], leading_dim =...*, product = B samples = tf.map_fn( fn=lambda qp: self._model.conditional_predict_f_sample( qp, Dataset( self._fantasized_query_points.value(), self._fantasized_observations.value() ), num_samples, ), elems=query_points_flatten, ) # [B, ..., S, L] return _restore_leading_dim(samples, leading_dim)
@check_shapes( "query_points: [broadcast batch..., N, D]", "return[0]: [batch..., ..., N, L]", "return[1]: [batch..., ..., N, L]", )
[docs] def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """ This function wraps conditional_predict_y. It cannot directly call conditional_predict_joint, since it does not accept query_points with rank > 2. We use tf.map_fn to allow leading dimensions for query_points. :param query_points: shape [...*, N, D] :return: mean, shape [...*, ..., N, L] and var, shape [...*, ..., N, L], where ... are the leading dimensions of fantasized_data """ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (tf.map_fn) fantasized_data = Dataset( self._fantasized_query_points.value(), self._fantasized_observations.value() ) return self._model.conditional_predict_y(qp, fantasized_data) return _broadcast_predict(query_points, fun)
[docs] def get_observation_noise(self) -> TensorType: return self._model.get_observation_noise()
[docs] def get_kernel(self) -> gpflow.kernels.Kernel: return self._model.get_kernel()
[docs] def log(self, dataset: Optional[Dataset] = None) -> None: return self._model.log(dataset)
[docs] def _broadcast_predict( query_points: TensorType, fun: Callable[[TensorType], tuple[TensorType, TensorType]] ) -> tuple[TensorType, TensorType]: """ Utility function that allows leading dimensions for query_points when fun only accepts rank 2 tensors. It works by flattening query_points into a rank 3 tensor, evaluate fun(query_points) through tf.map_fn, then restoring the leading dimensions. :param query_points: shape [...*, N, D] :param fun: callable that returns two tensors (e.g. a predict function) :return: two tensors (e.g. mean and variance) with shape [...*, ...] """ leading_dim, query_points_flatten = _get_leading_dim_and_flatten(query_points) # leading_dim =...*, product = B # query_points_flatten: [B, N, D] mean_signature = tf.TensorSpec(None, query_points.dtype) var_signature = tf.TensorSpec(None, query_points.dtype) mean, var = tf.map_fn( fn=fun, elems=query_points_flatten, fn_output_signature=(mean_signature, var_signature), ) # [B, ..., L, N], [B, ..., L, N] (predict_f) or [B, ..., L, N, N] (predict_joint) return _restore_leading_dim(mean, leading_dim), _restore_leading_dim(var, leading_dim)
[docs] def _get_leading_dim_and_flatten(query_points: TensorType) -> tuple[TensorType, TensorType]: """ :param query_points: shape [...*, N, D] :return: leading_dim = ....*, query_points_flatten, shape [B, N, D] """ leading_dim = tf.shape(query_points)[:-2] # =...*, product = B nd = tf.shape(query_points)[-2:] query_points_flatten = tf.reshape(query_points, (-1, nd[0], nd[1])) # [B, N, D] return leading_dim, query_points_flatten
[docs] def _restore_leading_dim(x: TensorType, leading_dim: TensorType) -> TensorType: """ "Un-flatten" the first dimension of x to leading_dim :param x: shape [B, ...] :param leading_dim: [...*] :return: shape [...*, ...] """ single_x_shape = tf.shape(x[0]) # = [...] output_x_shape = tf.concat([leading_dim, single_x_shape], axis=0) # = [...*, ...] return tf.reshape(x, output_x_shape) # [...*, ...]