gpflux.optimization.keras_natgrad
Support for the gpflow.optimizers.NaturalGradient optimizer within Keras models.
Module Contents
- class NatGradModel
Bases: gpflow.keras.tf_keras.Model
This is a drop-in replacement for tf.keras.Model when constructing GPflux models in the functional Keras style. It lets the model use NaturalGradient optimizers for the q(u) distributions in GP layers.
You must set the natgrad_layers property before compiling the model. Set it to the list of all GPLayer instances you want to train using natural gradients, or set it to True to include all of them.
This model’s compile() method must be passed a list of optimizers: one gpflow.optimizers.NaturalGradient instance per natgrad-trained GPLayer, followed by a regular optimizer (e.g. tf.keras.optimizers.Adam) as the last element. The regular optimizer handles all other parameters (hyperparameters, inducing point locations).
- property natgrad_layers: List[gpflux.layers.gp_layer.GPLayer]
The list of layers in this model that should be optimized using gpflow.optimizers.NaturalGradient.
- property optimizer: gpflow.keras.tf_keras.optimizers.Optimizer
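HACK to cope with Keras’s callbacks such as ReduceLROnPlateau and LearningRateScheduler, which are hardcoded to expect a single optimizer. This property exposes the regular optimizer, i.e. the last element of the list passed to compile().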
- train_step(data: Any) → Mapping[str, Any]
The logic for one training step. For more details of the implementation, see TensorFlow’s documentation of how to customize what happens in Model.fit.
- class NatGradWrapper(base_model: gpflow.keras.tf_keras.Model, *args: Any, **kwargs: Any)
Bases: NatGradModel
Wraps a class-based Keras model (e.g. the return value of gpflux.models.DeepGP.as_training_model) so that it works with gpflow.optimizers.NaturalGradient optimizers. For more details, see NatGradModel.
(Note that you can also pass NatGradModel directly as the default_model_class argument of DeepGP or as the model_class argument of as_training_model().)
Todo
This class will probably be removed in the future.
- Parameters:
base_model – the class-based Keras model to be wrapped