
Module Contents

class GPfluxPredictor(optimizer: trieste.models.optimizer.KerasOptimizer | None = None)

Bases: trieste.models.interfaces.SupportsGetObservationNoise, abc.ABC

A trainable wrapper for a GPflux deep Gaussian process model. Subclasses are expected to train with the Keras fit method, so they must expose both a model_keras (the compiled Keras model) and a model_gpflux (the underlying GPflux model).

Parameters:
optimizer – The optimizer wrapper, holding the optimizer used to train the model together with the arguments for the wrapper and the optimizer. The optimizer must be an instance of an Optimizer. Defaults to an Adam optimizer with a learning rate of 0.01.
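The subclassing contract described above can be sketched with the standard library alone. This is a hypothetical illustration, not trieste code: PredictorSketch, ToyPredictor, DEFAULT_OPTIMIZER and the string placeholders are invented so the example runs without TensorFlow or GPflux. It shows that the base class cannot be instantiated until a subclass provides both model_gpflux and model_keras, and that omitting the optimizer falls back to a default.

```python
import abc
from typing import Any, Optional


class PredictorSketch(abc.ABC):
    """Hypothetical stand-in for the GPfluxPredictor subclassing contract.

    Real subclasses return a gpflow.base.Module and a tensorflow.keras.Model;
    plain objects are used here so the sketch runs without TensorFlow.
    """

    # Hypothetical placeholder for "Adam with a 0.01 learning rate".
    DEFAULT_OPTIMIZER = ("Adam", 0.01)

    def __init__(self, optimizer: Optional[Any] = None) -> None:
        # Mirror the documented default: fall back when no optimizer is given.
        self._optimizer = self.DEFAULT_OPTIMIZER if optimizer is None else optimizer

    @property
    @abc.abstractmethod
    def model_gpflux(self) -> Any:
        """The underlying GPflux model."""

    @property
    @abc.abstractmethod
    def model_keras(self) -> Any:
        """The compiled Keras model used for training with fit."""

    @property
    def optimizer(self) -> Any:
        """The optimizer (wrapper) used for training."""
        return self._optimizer


class ToyPredictor(PredictorSketch):
    """A concrete subclass: it becomes instantiable only by supplying both properties."""

    def __init__(self, optimizer: Optional[Any] = None) -> None:
        super().__init__(optimizer)
        self._gpflux = "gpflux-model"  # hypothetical placeholder objects
        self._keras = "keras-model"

    @property
    def model_gpflux(self) -> Any:
        return self._gpflux

    @property
    def model_keras(self) -> Any:
        return self._keras


print(ToyPredictor().optimizer)  # falls back to the default placeholder
```

In trieste itself the default is an Adam optimizer with a 0.01 learning rate inside a KerasOptimizer wrapper; the tuple above only stands in for it.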

abstract property model_gpflux: gpflow.base.Module

The underlying GPflux model.

abstract property model_keras: tensorflow.keras.Model

Returns the compiled Keras model for training.

property optimizer: trieste.models.optimizer.KerasOptimizer

The optimizer wrapper for training the model.

predict(query_points: trieste.types.TensorType) → tuple[trieste.types.TensorType, trieste.types.TensorType]

Note: unless otherwise noted, this returns the mean and variance of the last layer, conditioned on a single sample from the previous layers.
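To illustrate what conditioning on one sample from the previous layers means, here is a toy two-layer model in plain Python. The layers and their means and variances are hypothetical; real GPflux layers are Gaussian process layers operating on tensors. Each call to predict draws one sample from the first layer and returns the second layer's mean and variance given that sample, so in this toy repeated calls return different means, and only an average over many calls approximates the marginal mean.

```python
import math
import random
from typing import Tuple


def layer1(x: float) -> Tuple[float, float]:
    """Hypothetical first layer: Gaussian with mean sin(x) and variance 0.1."""
    return math.sin(x), 0.1


def layer2(h: float) -> Tuple[float, float]:
    """Hypothetical last layer: Gaussian with mean 2*h and variance 0.05."""
    return 2.0 * h, 0.05


def predict(x: float, rng: random.Random) -> Tuple[float, float]:
    """Mean and variance of the last layer, conditioned on ONE sample of layer 1."""
    m1, v1 = layer1(x)
    h = rng.gauss(m1, math.sqrt(v1))  # a single draw from the previous layer
    return layer2(h)


rng = random.Random(0)
print(predict(1.0, rng))  # one conditional (mean, variance) pair
print(predict(1.0, rng))  # a different mean: a new sample was drawn

# Averaging the conditional means over many draws approximates the
# marginal mean E[2h] = 2*sin(1.0).
draws = [predict(1.0, rng)[0] for _ in range(20000)]
print(sum(draws) / len(draws), 2 * math.sin(1.0))
```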

abstract sample(query_points: trieste.types.TensorType, num_samples: int) → trieste.types.TensorType

Return num_samples samples from the independent marginal distributions at query_points.

Parameters:
  • query_points – The points at which to sample, with shape […, N, D].

  • num_samples – The number of samples at each point.

Returns:
The samples. For a predictive distribution with event shape E, this has shape […, S, N] + E, where S is the number of samples.
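The shape rule for sample can be checked with a small helper. sample_shape is hypothetical and only performs the shape arithmetic stated above: leading batch dimensions are kept, the input dimension D is dropped, S is inserted before N, and the event shape E is appended. The examples assume a single-output model, so E = [1].

```python
from typing import Tuple


def sample_shape(query_shape: Tuple[int, ...], num_samples: int,
                 event_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Result shape of sample(query_points, num_samples).

    query_points has shape [..., N, D]; the samples have shape [..., S, N] + E.
    """
    if len(query_shape) < 2:
        raise ValueError("query_points must have shape [..., N, D]")
    *batch, n, _d = query_shape  # D (the input dimension) does not appear in the output
    return tuple(batch) + (num_samples, n) + tuple(event_shape)


# 10 two-dimensional query points, 5 samples, single output (E = [1]).
print(sample_shape((10, 2), 5, (1,)))     # (5, 10, 1)
# The same with an extra leading batch dimension of size 3.
print(sample_shape((3, 10, 2), 5, (1,)))  # (3, 5, 10, 1)
```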

predict_y(query_points: trieste.types.TensorType) → tuple[trieste.types.TensorType, trieste.types.TensorType]

Note: unless otherwise noted, this returns the predictive mean and variance of the observations, conditioned on a single sample from the lower layers.
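Under a homoscedastic Gaussian likelihood, the observation-space prediction follows directly from the latent one: the mean is unchanged and the noise variance is added to the latent variance. The sketch below assumes such a likelihood, uses plain lists in place of tensors, and predict_y_gaussian is a hypothetical name. If f_mean and f_var come from predict, the result inherits the same one-sample conditioning.

```python
from typing import List, Tuple


def predict_y_gaussian(f_mean: List[float], f_var: List[float],
                       noise_variance: float) -> Tuple[List[float], List[float]]:
    """Observation mean and variance under a homoscedastic Gaussian likelihood."""
    # The mean passes through unchanged; each latent variance gains the noise variance.
    y_var = [v + noise_variance for v in f_var]
    return list(f_mean), y_var


y_mean, y_var = predict_y_gaussian([0.0, 1.5], [0.2, 0.4], noise_variance=0.1)
print(y_mean)  # [0.0, 1.5]
print(y_var)   # approximately [0.3, 0.5]
```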

get_observation_noise() → trieste.types.TensorType

Return the variance of observation noise for homoscedastic likelihoods.

Returns:
The observation noise.

Raises:
NotImplementedError – If the model does not have a homoscedastic likelihood.
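The homoscedastic check can be sketched as follows. GaussianLikelihood, HeteroscedasticLikelihood and get_observation_noise are hypothetical plain-Python stand-ins, not the trieste or GPflow classes: a likelihood with a single, input-independent noise variance reports it, and any other likelihood raises NotImplementedError, matching the contract documented above.

```python
class GaussianLikelihood:
    """Hypothetical homoscedastic likelihood: one noise variance for all inputs."""

    def __init__(self, variance: float) -> None:
        self.variance = variance


class HeteroscedasticLikelihood:
    """Hypothetical likelihood whose noise level depends on the input."""


def get_observation_noise(likelihood: object) -> float:
    """Return the noise variance, or raise if the likelihood is not homoscedastic."""
    # Only a homoscedastic likelihood has a single noise variance to report.
    if not isinstance(likelihood, GaussianLikelihood):
        raise NotImplementedError("observation noise requires a homoscedastic likelihood")
    return likelihood.variance


print(get_observation_noise(GaussianLikelihood(0.01)))  # 0.01
```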