trieste.models.optimizer

This module contains common optimizers based on Optimizer that can be used with models. Specific models can also subclass these optimizers or implement their own, and should register their loss functions using create_loss_function().

Module Contents

TrainingData

Type alias for a batch, or batches, of training data.
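As a rough illustration of what "a batch, or batches" means, the sketch below accepts either a single (inputs, targets) pair or an iterable of such pairs. Plain Python lists stand in for tensors, and both the alias spelled out here and the helper iter_batches are assumptions for illustration, not trieste's own definitions.

```python
from typing import Iterable, Iterator, List, Tuple, Union

# Stand-in for a tensor: a plain list of floats (assumption; trieste uses TensorType).
Tensor = List[float]
Batch = Tuple[Tensor, Tensor]

# Hypothetical mirror of the alias: one (inputs, targets) batch, or an iterable of batches.
TrainingData = Union[Batch, Iterable[Batch]]


def iter_batches(data: TrainingData) -> Iterator[Batch]:
    """Yield (inputs, targets) batches whether `data` is one batch or many (hypothetical helper)."""
    # A single batch is a 2-tuple whose first element is itself a tensor (here, a list).
    if isinstance(data, tuple) and len(data) == 2 and isinstance(data[0], list):
        yield data
    else:
        # Otherwise treat `data` as an iterable of batches.
        yield from data
```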

DatasetTransformer

Type alias for a function that converts a Dataset to batches of training data.
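A minimal sketch of such a transformer, assuming it receives the dataset plus an optional batch size and returns a list of shuffled mini-batches. ToyDataset is a hypothetical stand-in mirroring the query_points and observations fields of trieste.data.Dataset, and the exact call signature of the alias is an assumption.

```python
import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class ToyDataset:
    """Hypothetical stand-in for trieste.data.Dataset: paired query points and observations."""
    query_points: List[float]
    observations: List[float]


Batch = Tuple[List[float], List[float]]

# Assumed shape of the alias: (dataset, optional batch size) -> batches of training data.
DatasetTransformer = Callable[[ToyDataset, Optional[int]], List[Batch]]


def shuffled_batches(dataset: ToyDataset, batch_size: Optional[int], seed: int = 0) -> List[Batch]:
    """Shuffle the dataset, then split it into mini-batches of at most `batch_size` points."""
    n = len(dataset.query_points)
    size = batch_size if batch_size else max(n, 1)  # None or 0: a single full batch
    order = list(range(n))
    random.Random(seed).shuffle(order)
    batches = []
    for start in range(0, n, size):
        idx = order[start:start + size]
        # Keep each query point paired with its own observation.
        batches.append(([dataset.query_points[i] for i in idx],
                        [dataset.observations[i] for i in idx]))
    return batches


transformer: DatasetTransformer = shuffled_batches
```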

LossClosure

Type alias for a loss closure, typically used in optimization.
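A loss closure takes no arguments and returns the current loss, so an optimizer can call it repeatedly while updating the variables the closure reads. A minimal sketch, assuming a linear model whose weights are a plain list; make_mse_closure is a hypothetical name.

```python
from typing import Callable, List

# Assumed shape: a zero-argument function returning a scalar loss.
LossClosure = Callable[[], float]


def make_mse_closure(weights: List[float], xs: List[float], ys: List[float]) -> LossClosure:
    """Build a closure over a linear model y = w0 + w1 * x and fixed data (hypothetical)."""
    def loss() -> float:
        # Reads the *current* weights on every call, so in-place updates are picked up.
        residuals = [weights[0] + weights[1] * x - y for x, y in zip(xs, ys)]
        return sum(r * r for r in residuals) / len(residuals)
    return loss
```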

OptimizeResult

Optimization result. TensorFlow optimizers don’t return a result, so it is None for them; for the commonly used SciPy optimizer, it is a scipy.optimize.OptimizeResult.
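Because the result may be None or an OptimizeResult, code that inspects it should check for None first. In this hypothetical sketch, types.SimpleNamespace stands in for a scipy.optimize.OptimizeResult; success and fun are real OptimizeResult attributes, but report() is illustrative only.

```python
from types import SimpleNamespace
from typing import Any, Optional


def report(result: Optional[Any]) -> str:
    """Summarize an optimization result that may be None (hypothetical helper)."""
    if result is None:
        # e.g. a TensorFlow optimizer, which returns no result
        return "no result"
    # OptimizeResult exposes fields such as `success` and `fun` as attributes.
    return f"success={result.success}, fun={result.fun}"


# Stand-in for a SciPy OptimizeResult with illustrative values.
example = SimpleNamespace(success=True, fun=0.5)
```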

class Optimizer

Optimizer for training models with all the training data at once.

optimizer: Any

The underlying optimizer to use. For example, one of the subclasses of TensorFlow’s Optimizer could be used. The flexible type Any allows for the various optimizers that specific models might need.

minimize_args: dict[str, Any]

The keyword arguments to pass to the minimize() method of the optimizer.

compile: bool = False

If True, the optimization process will be compiled with tf.function().

create_loss(model: tensorflow.Module, dataset: trieste.data.Dataset) → LossClosure

Build a loss function for the specified model with the dataset using create_loss_function().

Parameters
  • model – The model to build a loss function for.

  • dataset – The data with which to build the loss function.

Returns

The loss function.

optimize(model: tensorflow.Module, dataset: trieste.data.Dataset) → OptimizeResult

Optimize the specified model with the dataset.

Parameters
  • model – The model to optimize.

  • dataset – The data with which to optimize the model.

Returns

The return value of the optimizer’s minimize() method.
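The full-batch flow (build one loss closure over all the data, then pass it to the underlying optimizer’s minimize() together with minimize_args) can be sketched in plain Python. Everything here is a hypothetical stand-in, not trieste’s implementation: ToyGradientDescent plays the underlying optimizer, the weights of ToyLinearModel play the trainable variables (assuming, as is typical, that minimize() receives them alongside the closure), and the compile flag is omitted.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple

LossClosure = Callable[[], float]
Data = Tuple[List[float], List[float]]


class ToyGradientDescent:
    """Hypothetical stand-in for an underlying optimizer exposing minimize(loss, variables)."""

    def __init__(self, learning_rate: float = 0.1, eps: float = 1e-5) -> None:
        self.learning_rate = learning_rate
        self.eps = eps

    def minimize(self, loss_fn: LossClosure, variables: List[float], steps: int = 200) -> Dict[str, float]:
        # Gradient descent with central-difference gradients; updates `variables` in place.
        for _ in range(steps):
            grads = []
            for i in range(len(variables)):
                old = variables[i]
                variables[i] = old + self.eps
                up = loss_fn()
                variables[i] = old - self.eps
                down = loss_fn()
                variables[i] = old
                grads.append((up - down) / (2 * self.eps))
            for i, g in enumerate(grads):
                variables[i] -= self.learning_rate * g
        return {"fun": loss_fn()}


@dataclass
class ToyLinearModel:
    """Hypothetical model y = w0 + w1 * x; its weights act as the trainable variables."""
    weights: List[float] = field(default_factory=lambda: [0.0, 0.0])


@dataclass
class FullBatchOptimizer:
    """Sketch mirroring the optimizer and minimize_args fields of Optimizer."""
    optimizer: Any = field(default_factory=ToyGradientDescent)
    minimize_args: Dict[str, Any] = field(default_factory=dict)

    def create_loss(self, model: ToyLinearModel, data: Data) -> LossClosure:
        xs, ys = data

        def loss() -> float:
            # Mean squared error over all the training data at once.
            res = [model.weights[0] + model.weights[1] * x - y for x, y in zip(xs, ys)]
            return sum(r * r for r in res) / len(res)

        return loss

    def optimize(self, model: ToyLinearModel, data: Data) -> Any:
        loss_fn = self.create_loss(model, data)
        # Forward minimize_args as keyword arguments and return minimize()'s result.
        return self.optimizer.minimize(loss_fn, model.weights, **self.minimize_args)
```

Here minimize_args={"steps": 500} would be forwarded to minimize() as a keyword argument, which is the role the documented minimize_args field plays.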

class BatchOptimizer

Bases: Optimizer

Optimizer for training models with mini-batches of training data.

max_iter: int = 100

The number of iterations over which to optimize the model.

batch_size: int = 100

The size of the mini-batches.

dataset_builder: DatasetTransformer | None

A function that maps the Observer data (a Dataset) to mini-batches of training data.

create_loss(model: tensorflow.Module, dataset: trieste.data.Dataset) → LossClosure

Build a loss function for the specified model with the dataset.

Parameters
  • model – The model to build a loss function for.

  • dataset – The data with which to build the loss function.

Returns

The loss function.

optimize(model: tensorflow.Module, dataset: trieste.data.Dataset) → None

Optimize the specified model with the dataset.

Parameters
  • model – The model to optimize.

  • dataset – The data with which to optimize the model.
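The mini-batch loop can be sketched in the same plain-Python style. The assumed behaviour is that each of max_iter iterations draws the next mini-batch from a reshuffled, repeating stream and takes one gradient step on it. To keep the finite-difference stand-in correct, the sketch builds a fresh closure per batch; trieste’s actual closure mechanics may differ, and all names here are hypothetical.

```python
import random
from typing import Callable, Iterator, List, Tuple

Batch = Tuple[List[float], List[float]]
LossClosure = Callable[[], float]


def mse_closure(weights: List[float], batch: Batch) -> LossClosure:
    """Loss closure for y = w0 + w1 * x over one mini-batch (hypothetical helper)."""
    xs, ys = batch

    def loss() -> float:
        res = [weights[0] + weights[1] * x - y for x, y in zip(xs, ys)]
        return sum(r * r for r in res) / len(res)

    return loss


def gradient_step(loss_fn: LossClosure, variables: List[float], lr: float = 0.05, eps: float = 1e-5) -> None:
    """One central-difference gradient-descent step, updating `variables` in place."""
    grads = []
    for i in range(len(variables)):
        old = variables[i]
        variables[i] = old + eps
        up = loss_fn()
        variables[i] = old - eps
        down = loss_fn()
        variables[i] = old
        grads.append((up - down) / (2 * eps))
    for i, g in enumerate(grads):
        variables[i] -= lr * g


def batch_optimize(weights: List[float], xs: List[float], ys: List[float],
                   max_iter: int = 100, batch_size: int = 100, seed: int = 0) -> None:
    """Take max_iter steps, each on the next mini-batch of a reshuffled, repeating stream."""
    rng = random.Random(seed)

    def stream() -> Iterator[Batch]:
        order = list(range(len(xs)))
        while True:  # repeat the data indefinitely, reshuffling on every pass
            rng.shuffle(order)
            for start in range(0, len(order), batch_size):
                idx = order[start:start + batch_size]
                yield [xs[i] for i in idx], [ys[i] for i in idx]

    batches = stream()
    for _ in range(max_iter):
        gradient_step(mse_closure(weights, next(batches)), weights)
```

The defaults max_iter=100 and batch_size=100 mirror the documented attribute defaults; the step size and seed are arbitrary choices for the sketch.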

class KerasOptimizer

Optimizer wrapper for training models implemented with Keras.

optimizer: tensorflow.keras.optimizers.Optimizer

The underlying optimizer to use for training the model.

fit_args: dict[str, Any]

The keyword arguments to pass to the fit() method of a tf.keras.Model instance. See https://keras.io/api/models/model_training_apis/#fit-method for a list of possible arguments in the dictionary.
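For instance, a fit_args dictionary might look like the following. The values are arbitrary illustrations, not recommended settings; each key is a standard argument of Keras’ Model.fit(), to which the dictionary’s entries are passed as keyword arguments.

```python
# Illustrative values only; each key is a standard argument of tf.keras.Model.fit().
fit_args = {
    "epochs": 400,      # number of passes over the training data
    "batch_size": 32,   # samples per gradient update
    "verbose": 0,       # silence per-epoch progress output
    "shuffle": True,    # reshuffle the training data before each epoch
}
```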

loss: Optional[Union[tensorflow.keras.losses.Loss, Callable[[trieste.types.TensorType, tensorflow_probability.distributions.Distribution], trieste.types.TensorType]]]

Optional loss function for training the model.

metrics: Optional[list[tensorflow.keras.metrics.Metric]]

Optional metrics for monitoring the performance of the network.

create_loss_function(model: Any, dataset: TrainingData, compile: bool = False) → LossClosure

Generic function for building a loss function for a specified model and dataset. The implementation depends on the type of the model: each model type makes its own loss function available by registering it through this function’s register method, used as a decorator.

Parameters
  • model – The model to build a loss function for.

  • dataset – The data with which to build the loss function.

  • compile – Whether to compile with tf.function().

Returns

The loss function.
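The decorator-and-register pattern described above matches Python’s functools.singledispatch, so it can be sketched with it. This assumes create_loss_function dispatches on the type of its model argument in the same way; ToyLinearModel and _create_linear_loss are hypothetical names, and the compile argument is accepted but ignored here.

```python
from functools import singledispatch
from typing import Any, Callable, List, Tuple

LossClosure = Callable[[], float]
TrainingData = Tuple[List[float], List[float]]


@singledispatch
def create_loss_function(model: Any, dataset: TrainingData, compile: bool = False) -> LossClosure:
    # Fallback for model types that have not registered a loss function.
    raise NotImplementedError(f"Unknown model {model!r} passed for loss function extraction")


class ToyLinearModel:
    """Hypothetical model: y = w0 + w1 * x."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]


@create_loss_function.register
def _create_linear_loss(model: ToyLinearModel, dataset: TrainingData, compile: bool = False) -> LossClosure:
    # singledispatch picks this implementation from the annotation on `model`.
    xs, ys = dataset

    def loss() -> float:
        res = [model.weights[0] + model.weights[1] * x - y for x, y in zip(xs, ys)]
        return sum(r * r for r in res) / len(res)

    # `compile` is ignored in this sketch; the documented parameter requests tf.function().
    return loss
```

Calling create_loss_function with an unregistered model type falls through to the generic implementation and raises NotImplementedError.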