Source code for gpflux.layers.bayesian_dense_layer

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module provides :class:`BayesianDenseLayer`, which implements a
variational Bayesian dense (fully-connected) neural network layer as a Keras
:class:`~tf.keras.layers.Layer`.
"""

from typing import Callable, Optional, Union

import numpy as np
import tensorflow as tf

from gpflow import Parameter, default_float
from gpflow.base import TensorType
from gpflow.kullback_leiblers import gauss_kl
from gpflow.models.model import MeanAndVariance
from gpflow.utilities.bijectors import positive, triangular

from gpflux.helpers import xavier_initialization_numpy
from gpflux.layers.trackable_layer import TrackableLayer
from gpflux.types import ShapeType


[docs]class BayesianDenseLayer(TrackableLayer): """ A dense (fully-connected) layer for variational Bayesian neural networks. This layer holds the mean and square-root of the variance of the distribution over the weights. This layer also has a temperature for cooling (or heating) the posterior. """ def __init__( self, input_dim: int, output_dim: int, num_data: int, w_mu: Optional[np.ndarray] = None, w_sqrt: Optional[np.ndarray] = None, activation: Optional[Callable] = None, is_mean_field: bool = True, temperature: float = 1e-4, ): """ :param input_dim: The input dimension (excluding bias) of this layer. :param output_dim: The output dimension of this layer. :param num_data: The number of points in the training dataset (used for scaling the KL regulariser). :param w_mu: Initial value of the variational mean for weights + bias. If not specified, this defaults to `xavier_initialization_numpy` for the weights and zero for the bias. :param w_sqrt: Initial value of the variational Cholesky of the (co)variance for weights + bias. If not specified, this defaults to 1e-5 * Identity. :param activation: The activation function. If not specified, this defaults to the identity. :param is_mean_field: Determines whether the approximation to the weight posterior is mean field. Must be consistent with the shape of ``w_sqrt``, if specified. :param temperature: The KL loss will be scaled by this factor. Can be used for cooling (< 1.0) or heating (> 1.0) the posterior. As suggested in `"How Good is the Bayes Posterior in Deep Neural Networks Really?" by Wenzel et al. (2020) <http://proceedings.mlr.press/v119/wenzel20a>`_ the default value is a cold ``1e-4``. """ super().__init__(dtype=default_float()) assert input_dim >= 1 assert output_dim >= 1 assert num_data >= 1 if w_mu is not None: # add + 1 for the bias assert w_mu.shape == ((input_dim + 1) * output_dim,) if w_sqrt is not None: if not is_mean_field: assert w_sqrt.shape == ( (input_dim + 1) * output_dim, (input_dim + 1) * output_dim, ) else: assert w_sqrt.shape == ((input_dim + 1) * output_dim,) assert temperature > 0.0 self.input_dim = input_dim self.output_dim = output_dim self.num_data = num_data self.w_mu_ini = w_mu self.w_sqrt_ini = w_sqrt self.activation = activation self.is_mean_field = is_mean_field self.temperature = temperature self.dim = (input_dim + 1) * output_dim self.full_output_cov = False self.full_cov = False self.w_mu = Parameter(np.zeros((self.dim,)), dtype=default_float(), name="w_mu") # [dim] self.w_sqrt = Parameter( np.zeros((self.dim, self.dim)) if not self.is_mean_field else np.ones((self.dim,)), transform=triangular() if not self.is_mean_field else positive(), dtype=default_float(), name="w_sqrt", ) # [dim, dim] or [dim] def initialize_variational_distribution(self) -> None: if self.w_mu_ini is None: w = xavier_initialization_numpy(self.input_dim, self.output_dim) b = np.zeros((1, self.output_dim)) self.w_mu_ini = np.concatenate((w, b), axis=0).reshape((self.dim,)) self.w_mu.assign(self.w_mu_ini) if self.w_sqrt_ini is None: if not self.is_mean_field: self.w_sqrt_ini = 1e-5 * np.eye(self.dim) else: self.w_sqrt_ini = 1e-5 * np.ones((self.dim,)) self.w_sqrt.assign(self.w_sqrt_ini)
[docs] def build(self, input_shape: ShapeType) -> None: """Build the variables necessary on first call""" super().build(input_shape) self.initialize_variational_distribution()
[docs] def predict_samples( self, inputs: TensorType, *, num_samples: Optional[int] = None, ) -> tf.Tensor: """ Samples from the approximate posterior at N test inputs, with input_dim = D, output_dim = Q. :param inputs: The inputs to predict at; shape ``[N, D]``. :param num_samples: The number of samples S, to draw. :returns: Samples, shape ``[S, N, Q]`` if S is not None else ``[N, Q]``. """ _num_samples = num_samples or 1 z = tf.random.normal((self.dim, _num_samples), dtype=default_float()) # [dim, S] if not self.is_mean_field: w = self.w_mu[:, None] + tf.matmul(self.w_sqrt, z) # [dim, S] else: w = self.w_mu[:, None] + self.w_sqrt[:, None] * z # [dim, S] N = tf.shape(inputs)[0] inputs_concat_1 = tf.concat( (inputs, tf.ones((N, 1), dtype=default_float())), axis=-1 ) # [N, D+1] samples = tf.tensordot( inputs_concat_1, tf.reshape(tf.transpose(w), (_num_samples, self.input_dim + 1, self.output_dim)), [[-1], [1]], ) # [N, S, Q] if num_samples is None: samples = tf.squeeze(samples, axis=-2) # [N, Q] else: samples = tf.transpose(samples, perm=[1, 0, 2]) # [S, N, Q] if self.activation is not None: samples = self.activation(samples) return samples
[docs] def call( self, inputs: TensorType, training: Optional[bool] = False ) -> Union[tf.Tensor, MeanAndVariance]: """ The default behaviour upon calling this layer. """ sample = self.predict_samples( inputs, num_samples=None, ) # TF quirk: add_loss must add a tensor to compile if training: loss = self.temperature * self.prior_kl() else: loss = tf.constant(0.0, dtype=default_float()) loss_per_datapoint = loss / self.num_data self.add_loss(loss_per_datapoint) return sample # [N, Q]
[docs] def prior_kl(self) -> tf.Tensor: """ Returns the KL divergence ``KL[q(u)∥p(u)]`` from the prior ``p(u) = N(0, I)`` to the variational distribution ``q(u) = N(w_mu, w_sqrt²)``. """ return gauss_kl( self.w_mu[:, None], self.w_sqrt[None] if not self.is_mean_field else self.w_sqrt[:, None], )