trieste.models.keras.sampler#
This module is the home of the sampling functionality required by some of Trieste's Keras model wrappers.
Module Contents#
- class DeepEnsembleTrajectorySampler(model: trieste.models.keras.interface.DeepEnsembleModel, diversify: bool = False, seed: int | None = None)[source]#
Bases: trieste.models.interfaces.TrajectorySampler[trieste.models.keras.interface.DeepEnsembleModel]
This class builds functions that approximate a trajectory by randomly choosing a network from the ensemble and using its predicted means as a trajectory.
The diversify option can be used to increase diversity when optimizing very large batches of trajectories. In that case, quantiles from the approximate Gaussian distribution of the ensemble are used as trajectories: each trajectory is approximated by a randomly chosen quantile, and a reparametrisation trick speeds up computation. Note that quantiles are not true trajectories, so this will likely have some performance costs.
- Parameters:
model – The ensemble model to sample from.
diversify – Whether to use quantiles from the approximate Gaussian distribution of the ensemble as trajectories (False by default). See class docstring for details.
seed – Random number seed to use for trajectory sampling.
- Raises:
NotImplementedError – If the model is not an instance of DeepEnsembleModel.
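The two sampling modes described for this class can be illustrated with a small stand-alone sketch. It does not use Trieste or TensorFlow: the `networks` list, the `sample_trajectory` helper, and the use of the spread of the members' mean predictions as the ensemble standard deviation are all assumptions made for illustration, not the library's implementation.

```python
import random
from statistics import mean, pstdev

# Hypothetical stand-ins for ensemble members: each maps x to a predicted mean.
networks = [
    lambda x: 1.0 * x,
    lambda x: 1.1 * x + 0.2,
    lambda x: 0.9 * x - 0.1,
]


def sample_trajectory(networks, diversify, rng):
    """Return a function x -> y approximating one draw from the ensemble."""
    if not diversify:
        # Default mode: pick one network at random and use its predicted mean.
        return rng.choice(networks)

    # Diversify mode: fix one standard-normal draw per trajectory, then shift
    # the ensemble mean by that many ensemble standard deviations
    # (reparametrisation: mean + std * eps selects a random quantile).
    eps = rng.gauss(0.0, 1.0)

    def quantile_trajectory(x):
        preds = [net(x) for net in networks]
        return mean(preds) + pstdev(preds) * eps

    return quantile_trajectory


rng = random.Random(0)
plain = sample_trajectory(networks, diversify=False, rng=rng)
diverse = sample_trajectory(networks, diversify=True, rng=rng)
```

In the default mode at most as many distinct trajectories exist as there are networks, which is why a quantile-based alternative helps for very large batches: every draw of `eps` gives a different curve.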
- get_trajectory() trieste.models.interfaces.TrajectoryFunction [source]#
Generate an approximate function draw (trajectory) from the ensemble.
- Returns:
A trajectory function representing an approximate trajectory from the model, taking an input of shape [N, B, D] and returning shape [N, B, L].
- update_trajectory(trajectory: trieste.models.interfaces.TrajectoryFunction) trieste.models.interfaces.TrajectoryFunction [source]#
Update a TrajectoryFunction to reflect an update in its underlying DeepEnsembleModel and resample accordingly.
Here we rely on the underlying models being updated and we only resample the trajectory.
- Parameters:
trajectory – The trajectory function to be resampled.
- Returns:
The new trajectory function, updated for the new model.
- resample_trajectory(trajectory: trieste.models.interfaces.TrajectoryFunction) trieste.models.interfaces.TrajectoryFunction [source]#
Efficiently resample a TrajectoryFunction in-place to avoid function retracing with every new sample.
- Parameters:
trajectory – The trajectory function to be resampled.
- Returns:
The new resampled trajectory function.
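The idea of resampling in place can be sketched without TensorFlow. The class and helper below are hypothetical and only show the pattern: the callable keeps its sampled state as a mutable attribute, so resampling changes that state and returns the same object instead of building a new function. In a TensorFlow setting the state would presumably be held in variables so that a compiled function stays valid; that detail is an assumption here.

```python
import random


class EnsembleTrajectory:
    """Hypothetical callable whose sampled state can be refreshed in place."""

    def __init__(self, networks, seed=None):
        self._networks = networks
        self._rng = random.Random(seed)
        self._index = 0
        self.resample()

    def resample(self):
        # Only the stored index changes; the callable object itself is reused.
        self._index = self._rng.randrange(len(self._networks))

    def __call__(self, x):
        return self._networks[self._index](x)


def resample_trajectory(trajectory):
    # Mutate and return the same object rather than constructing a new one.
    trajectory.resample()
    return trajectory


# Usage: the resampled trajectory is the very same object as before.
members = [lambda x: x, lambda x: 2.0 * x, lambda x: 3.0 * x]
trajectory = EnsembleTrajectory(members, seed=1)
resampled = resample_trajectory(trajectory)
```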
- class deep_ensemble_trajectory(model: trieste.models.keras.interface.DeepEnsembleModel, diversify: bool, seed: int | None = None)[source]#
Bases:
trieste.models.interfaces.TrajectoryFunctionClass
Generate an approximate function draw (trajectory) by randomly choosing a batch B of networks from the ensemble and using their predicted means as trajectories.
The diversify option can be used to increase diversity when optimizing very large batches of trajectories. In that case, quantiles from the approximate Gaussian distribution of the ensemble are used as trajectories: each trajectory is approximated by a randomly chosen quantile, and a reparametrisation trick speeds up computation. Note that quantiles are not true trajectories, so this will likely have some performance costs.
- Parameters:
model – The model of the objective function.
diversify – Whether to use samples from the final probabilistic layer as trajectories, or mean predictions.
seed – Optional RNG seed.
- __call__(x: trieste.types.TensorType) trieste.types.TensorType [source]#
Call the trajectory function. Note that we are flattening the batch dimension and doing a forward pass with each network in the ensemble with the whole batch. This is somewhat wasteful, but is necessary given the underlying KerasEnsemble network model.
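The flattening step described above can be pictured with plain nested lists. This is a sketch under stated assumptions: `flatten_batch`, `call_trajectories`, the `indices` argument (one chosen network per batch slot) and the toy networks are hypothetical names, and the real method operates on tensors rather than lists.

```python
def flatten_batch(x):
    """Flatten nested lists of shape [N][B][D] into [N*B][D]."""
    return [point for row in x for point in row]


def call_trajectories(x, networks, indices):
    """Evaluate B trajectories on x of shape [N][B][D]; returns [N][B][L].

    indices[j] is the ensemble member assigned to batch slot j.
    """
    n, b = len(x), len(x[0])
    flat_x = flatten_batch(x)  # [N*B][D]
    # Every network evaluates every flattened point, although each point only
    # needs the output of one network. This is the wasteful part.
    all_preds = [[net(p) for p in flat_x] for net in networks]  # [M][N*B][L]
    # Restore the [N][B] layout and keep, per slot j, its assigned network.
    return [
        [all_preds[indices[j]][i * b + j] for j in range(b)]
        for i in range(n)
    ]


# Two toy networks mapping a D-dimensional point to L = 1 outputs.
toy_networks = [lambda p: [sum(p)], lambda p: [2.0 * sum(p)]]
x = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # N=2, B=2, D=2
y = call_trajectories(x, toy_networks, indices=[0, 1])
```

Here `y` has the same leading `[N][B]` layout as `x`, with the first batch slot evaluated by the first network and the second slot by the second.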