trieste.acquisition.utils

Module Contents

split_acquisition_function(fn: trieste.acquisition.interface.AcquisitionFunction, split_size: int) → trieste.acquisition.interface.AcquisitionFunction

A wrapper around an AcquisitionFunction to split its input into batches. Splits x into batches along the first dimension, calls fn on each batch, and then stitches the results back together, so that it looks like fn was called with all of x in one batch.

Parameters
  • fn – Acquisition function to split.

  • split_size – Call fn with tensors of at most this size.

Returns

Split acquisition function.
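The split-and-stitch behaviour described above can be illustrated with a minimal sketch. It uses NumPy rather than trieste, and `split_fn_sketch` and `toy_acquisition` are hypothetical names introduced only for illustration; this is not the library's implementation.

```python
import numpy as np


def split_fn_sketch(fn, split_size):
    """Hypothetical stand-in for the wrapper described above (not trieste's code)."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")

    def wrapper(x):
        x = np.asarray(x)
        if x.shape[0] <= split_size:
            # Nothing to split: call fn once on the whole input.
            return fn(x)
        # Slice along the first dimension into chunks of at most split_size rows.
        chunks = [x[i:i + split_size] for i in range(0, x.shape[0], split_size)]
        # Evaluate each chunk separately and stitch the results back together.
        return np.concatenate([fn(chunk) for chunk in chunks], axis=0)

    return wrapper


call_sizes = []  # records how many points each call to the toy function received


def toy_acquisition(x):
    """Toy function mapping [N, 1, D] to [N, 1]; assumed purely for illustration."""
    call_sizes.append(x.shape[0])
    return -np.sum(x ** 2, axis=-1)


x = np.random.default_rng(0).normal(size=(10, 1, 2))
result = split_fn_sketch(toy_acquisition, split_size=4)(x)
```

Here `call_sizes` records the size of each batch passed to the toy function, while `result` has the same shape and values as a single call on all of `x`.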

split_acquisition_function_calls(optimizer: trieste.acquisition.optimizer.AcquisitionOptimizer[trieste.space.SearchSpaceType], split_size: int) → trieste.acquisition.optimizer.AcquisitionOptimizer[trieste.space.SearchSpaceType]

A wrapper around an AcquisitionOptimizer. It wraps the optimizer so that evaluations of the acquisition function are split into batches along the first dimension and then stitched back together. This can be useful to reduce memory usage when evaluating functions over large spaces.

Parameters
  • optimizer – An optimizer that returns batches of points with shape [V, …].

  • split_size – The desired maximum number of points in acquisition function evaluations.

Returns

An AcquisitionOptimizer that still returns points with the shape [V, …] but evaluates at most split_size points at a time.
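As a rough sketch of the idea, an optimizer can be treated as a callable taking a search space and an acquisition function. This is a simplifying assumption made for illustration; `split_calls_sketch`, `toy_grid_optimizer` and `toy_acquisition` are hypothetical NumPy stand-ins, not trieste's API. The wrapper passes a batched version of the acquisition function to the unchanged optimizer.

```python
import numpy as np


def split_calls_sketch(optimizer, split_size):
    """Hypothetical wrapper: the optimizer only ever sees a batched acquisition function."""

    def wrapped(space, fn):
        def split_fn(x):
            x = np.asarray(x)
            if x.shape[0] <= split_size:
                return fn(x)
            # Evaluate at most split_size points per call, then stitch the results.
            parts = [fn(x[i:i + split_size]) for i in range(0, x.shape[0], split_size)]
            return np.concatenate(parts, axis=0)

        return optimizer(space, split_fn)

    return wrapped


def toy_grid_optimizer(space, fn):
    """Toy optimizer (an assumption): `space` is an [N, D] array of candidate points."""
    values = fn(space)  # shape [N]
    return space[np.argmax(values)][None, :]  # best point, shape [1, D]


batch_sizes = []  # sizes of the batches the acquisition function actually receives


def toy_acquisition(x):
    batch_sizes.append(x.shape[0])
    return -np.sum((x - 0.3) ** 2, axis=-1)


candidates = np.linspace(-1.0, 1.0, 101)[:, None]
plain_best = toy_grid_optimizer(candidates, toy_acquisition)

batch_sizes.clear()
split_optimizer = split_calls_sketch(toy_grid_optimizer, split_size=25)
split_best = split_optimizer(candidates, toy_acquisition)
```

The wrapped optimizer returns the same point as the plain one, but no single acquisition function call receives more than `split_size` points, which is where the memory saving would come from.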

select_nth_output(x: trieste.types.TensorType, output_dim: int = 0) → trieste.types.TensorType

A utility function for trajectory sampler-related acquisition functions which selects the n-th output as the trajectory to be used, with n specified by output_dim. Defaults to the first output.

Parameters
  • x – Input with shape […, B, L], where L is the number of outputs of the model.

  • output_dim – Index of the output to be selected. Defaults to the first output.

Returns

TensorType with shape […, B], where the output at index output_dim has been selected from the last dimension of the input.
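The shape change described above corresponds to indexing the last axis. The following NumPy sketch assumes the selection is equivalent to `x[..., output_dim]`, which is inferred from the documented shapes rather than from the library source.

```python
import numpy as np

# Toy input with shape [N, B, L] = [2, 3, 4]: L = 4 model outputs per point.
x = np.arange(24).reshape(2, 3, 4)

# Selecting one output drops the last axis, giving shape [N, B].
first_output = x[..., 0]  # analogous to the default output_dim=0
third_output = x[..., 2]  # analogous to output_dim=2
```

Both results have shape `(2, 3)`: the trailing output axis is removed and the leading batch dimensions are left untouched.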