Source code for gpflux.layers.latent_variable_layer

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" This module implements a latent variable layer for deep GPs. """

import abc
from typing import Optional, Tuple

import tensorflow as tf
import tensorflow_probability as tfp

from gpflow import default_float
from gpflow.base import TensorType
from gpflow.keras import tf_keras

from gpflux.layers.trackable_layer import TrackableLayer
from gpflux.types import ObservationType


class LayerWithObservations(TrackableLayer, metaclass=abc.ABCMeta):
    """
    By inheriting from this class, Layers indicate that their :meth:`call` method
    takes a second *observations* argument after the customary *layer_inputs*
    argument. This is used to distinguish which layers (unlike most standard Keras
    layers) require the original inputs and/or targets during training. For example,
    it is used by the amortized variational inference in the
    :class:`LatentVariableLayer`.
    """

    @abc.abstractmethod
    def call(
        self,
        layer_inputs: TensorType,
        observations: Optional[ObservationType] = None,
        training: Optional[bool] = None,
    ) -> tf.Tensor:
        """
        The :meth:`call` method of `LayerWithObservations` subclasses should accept
        a second argument, *observations*. In training mode, this will be the
        ``[inputs, targets]`` of the training points; otherwise, it is `None`.
        """


class LatentVariableLayer(LayerWithObservations):
    """
    A latent variable layer, with amortized mean-field variational inference.

    The latent variable is distribution-agnostic, but assumes a variational
    posterior that is fully factorised and is of the same distribution family
    as the prior.

    This class is used by models as described in
    :cite:p:`dutordoir2018cde, salimbeni2019iwvi`.
    """

    prior: tfp.distributions.Distribution
    """ The prior distribution for the latent variables. """

    encoder: tf_keras.layers.Layer
    """
    An encoder that maps from a concatenation of inputs and targets to the
    parameters of the approximate posterior distribution of the corresponding
    latent variables.
    """

    compositor: tf_keras.layers.Layer
    """
    A layer that takes as input the two-element ``[layer_inputs, latent_variable_samples]``
    list and combines the elements into a single output tensor.
    """

    def __init__(
        self,
        prior: tfp.distributions.Distribution,
        encoder: tf_keras.layers.Layer,
        compositor: Optional[tf_keras.layers.Layer] = None,
        name: Optional[str] = None,
    ):
        """
        :param prior: A distribution that represents the :attr:`prior` over the
            latent variable.
        :param encoder: A layer which is passed the concatenated observation inputs
            and targets, and returns the appropriate parameters for the approximate
            posterior distribution; see :attr:`encoder`.
        :param compositor: A layer that combines layer inputs and latent variable
            samples into a single tensor; see :attr:`compositor`. If you do not
            specify a value for this parameter, the default is
            ``tf.keras.layers.Concatenate(axis=-1, dtype=default_float())``. Note
            that you should set the ``dtype`` of this layer to GPflow's default
            dtype, as given by :meth:`~gpflow.default_float()`.
        :param name: The name of this layer (passed through to
            `tf.keras.layers.Layer`).
        """
        super().__init__(dtype=default_float(), name=name)
        self.prior = prior
        self.distribution_class = prior.__class__
        self.encoder = encoder
        self.compositor = (
            compositor
            if compositor is not None
            else tf_keras.layers.Concatenate(axis=-1, dtype=default_float())
        )

    def call(
        self,
        layer_inputs: TensorType,
        observations: Optional[ObservationType] = None,
        training: Optional[bool] = None,
        seed: Optional[int] = None,
    ) -> tf.Tensor:
        r"""
        Sample the latent variables and compose them with the layer input.

        When training, draw a sample of the latent variable from the posterior,
        whose distribution is parameterised by the encoder mapping from the data.
        Also add a KL divergence [posterior∥prior] to the losses.

        When not training, draw a sample of the latent variable from the prior.

        :param layer_inputs: The output of the previous layer.
        :param observations: The ``[inputs, targets]``, with the shapes
            ``[batch size, Din]`` and ``[batch size, Dout]`` respectively. This
            parameter should be passed only when in training mode.
        :param training: The training mode indicator.
        :param seed: A random seed for the sampling operation.
        :returns: Samples of the latent variable composed with the layer inputs
            through the :attr:`compositor`.
        """
        if training:
            if observations is None:
                raise ValueError("LatentVariableLayer requires observations when training")

            samples, loss_per_datapoint = self._inference_latent_samples_and_loss(
                layer_inputs, observations, seed=seed
            )
        else:
            samples = self._prediction_latent_samples(layer_inputs, seed=seed)
            loss_per_datapoint = tf.constant(0.0, dtype=default_float())

        self.add_loss(loss_per_datapoint)

        # Metric names should be unique; otherwise they get overwritten if you
        # have multiple with the same name.
        name = f"{self.name}_local_kl" if self.name else "local_kl"
        self.add_metric(loss_per_datapoint, name=name, aggregation="mean")

        return self.compositor([layer_inputs, samples])

    def _inference_posteriors(
        self,
        observations: ObservationType,
        training: Optional[bool] = None,
    ) -> tfp.distributions.Distribution:
        """
        Return the posterior distributions parametrised by the :attr:`encoder`,
        which gets called with the concatenation of the inputs and targets in the
        *observations* argument.

        .. todo:: We might want to change encoders to have a
            `tfp.layers.DistributionLambda` final layer that directly returns the
            appropriately parameterised distributions object.

        :param observations: The ``[inputs, targets]``, with the shapes
            ``[batch size, Din]`` and ``[batch size, Dout]`` respectively.
        :param training: The training mode indicator (passed through to the
            :attr:`encoder`'s call).
        :returns: The posterior distributions object.
        """
        inputs, targets = observations
        encoder_inputs = tf.concat(observations, axis=-1)
        distribution_params = self.encoder(encoder_inputs, training=training)
        posteriors = self.distribution_class(*distribution_params, allow_nan_stats=False)
        return posteriors

    def _inference_latent_samples_and_loss(
        self, layer_inputs: TensorType, observations: ObservationType, seed: Optional[int] = None
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Sample latent variables during the *training* forward pass, hence requiring
        the observations. Also return the KL loss per datapoint.

        :param layer_inputs: The output of the previous layer *(unused)*.
        :param observations: The ``[inputs, targets]``, with the shapes
            ``[batch size, Din]`` and ``[batch size, Dout]`` respectively.
        :param seed: A random seed for the sampling operation.
        :returns: The samples and the loss-per-datapoint.
        """
        posteriors = self._inference_posteriors(observations, training=True)
        samples = posteriors.sample(seed=seed)  # [N, Dw]

        # closed-form expectation E_q[log(q/p)] = KL[q∥p]:
        local_kls = self._local_kls(posteriors)
        loss_per_datapoint = tf.reduce_mean(local_kls, name="local_kls")
        return samples, loss_per_datapoint

    def _prediction_latent_samples(
        self, layer_inputs: TensorType, seed: Optional[int] = None
    ) -> tf.Tensor:
        """
        Sample latent variables during the *prediction* forward pass, only depending
        on the shape of this layer's inputs.

        :param layer_inputs: The output of the previous layer (for determining the
            batch shape).
        :param seed: A random seed for the sampling operation.
        :returns: The samples.
        """
        batch_shape = tf.shape(layer_inputs)[:-1]
        samples = self.prior.sample(batch_shape, seed=seed)
        return samples

    def _local_kls(self, posteriors: tfp.distributions.Distribution) -> tf.Tensor:
        """
        Compute the KL divergences [posteriors∥prior].

        :param posteriors: A distribution that represents the approximate posteriors.
        :returns: The KL divergences from the prior for each of the posteriors.
        """
        return posteriors.kl_divergence(self.prior)
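

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module source above): one possible
# way to construct and call a ``LatentVariableLayer``, using a standard-normal
# prior and a small Keras encoder that returns the ``(loc, scale_diag)``
# parameters expected by ``tfp.distributions.MultivariateNormalDiag``. The
# encoder architecture, latent dimensionality, and data shapes below are
# illustrative assumptions only.
if __name__ == "__main__":

    class MeanFieldEncoder(tf_keras.layers.Layer):
        """Maps concatenated [inputs, targets] to (loc, scale_diag) of the posterior."""

        def __init__(self, latent_dim: int):
            super().__init__(dtype=default_float())
            self.hidden = tf_keras.layers.Dense(16, activation="relu", dtype=default_float())
            self.loc = tf_keras.layers.Dense(latent_dim, dtype=default_float())
            self.log_scale = tf_keras.layers.Dense(latent_dim, dtype=default_float())

        def call(
            self, inputs: TensorType, training: Optional[bool] = None
        ) -> Tuple[tf.Tensor, tf.Tensor]:
            hidden = self.hidden(inputs)
            # exp() keeps the scale parameters strictly positive
            return self.loc(hidden), tf.exp(self.log_scale(hidden))

    latent_dim = 2
    prior = tfp.distributions.MultivariateNormalDiag(
        loc=tf.zeros(latent_dim, dtype=default_float()),
        scale_diag=tf.ones(latent_dim, dtype=default_float()),
    )
    lv_layer = LatentVariableLayer(prior=prior, encoder=MeanFieldEncoder(latent_dim))

    # Toy data: 7 datapoints with Din=3 inputs and Dout=1 targets.
    X = tf.random.normal([7, 3], dtype=default_float())
    Y = tf.random.normal([7, 1], dtype=default_float())

    # Training mode: latent samples come from the encoder's posterior, and the
    # mean KL[posterior∥prior] is added to the layer's losses and metrics.
    train_output = lv_layer(X, observations=[X, Y], training=True)  # [7, 3 + latent_dim]

    # Prediction mode: latent samples come from the prior; no observations needed.
    test_output = lv_layer(X, training=False)  # [7, 3 + latent_dim]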