# Source code for gpflux.layers.basis_functions.fourier_features.base

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" Shared functionality for stationary kernel basis functions. """

from abc import ABC, abstractmethod
from typing import Mapping

import tensorflow as tf

import gpflow
from gpflow.base import TensorType

from gpflux.types import ShapeType


class FourierFeaturesBase(ABC, tf.keras.layers.Layer):
    r"""
    The base class for all Fourier feature layers, used for both random Fourier feature layers and
    quadrature layers. We subclass :class:`tf.keras.layers.Layer`, so we must provide
    :meth:`build` and :meth:`call` methods.
    """

    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
        """
        :param kernel: kernel to approximate using a set of Fourier bases.
        :param n_components: number of components (e.g. Monte Carlo samples,
            quadrature nodes, etc.) used to numerically approximate the kernel.
        :param kwargs: forwarded unchanged to :class:`tf.keras.layers.Layer`. If a
            truthy ``input_dim`` is among them, the layer is built immediately for
            inputs of that dimension.
        """
        super().__init__(**kwargs)
        self.kernel = kernel
        self.n_components = n_components
        if isinstance(kernel, gpflow.kernels.MultioutputKernel):
            self.is_multioutput = True
            self.num_latent_gps = kernel.num_latent_gps
        else:
            self.is_multioutput = False
            self.num_latent_gps = 1

        if kwargs.get("input_dim", None):
            # Eagerly build so that weights exist before the first call.
            self._input_dim = kwargs["input_dim"]
            self.build(tf.TensorShape([self._input_dim]))
        else:
            self._input_dim = None

    def call(self, inputs: TensorType) -> tf.Tensor:
        """
        Evaluate the basis functions at ``inputs``.

        :param inputs: The evaluation points, a tensor with the shape ``[N, D]``.

        :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]`` in the
            multioutput case.
        """
        if self.is_multioutput:
            # Rescale the inputs separately by each latent kernel's lengthscales.
            X = [tf.divide(inputs, k.lengthscales) for k in self.kernel.latent_kernels]
            X = tf.stack(X, 0)  # [1, N, D] or [P, N, D]
        else:
            X = tf.divide(inputs, self.kernel.lengthscales)  # [N, D]
        const = self._compute_constant()  # [] or [P, 1, 1]
        bases = self._compute_bases(X)  # [N, M] or [P, N, M]
        output = const * bases
        # tf.ensure_shape returns the checked tensor. Use its result rather than
        # discarding it, so the static shape is set on the tensor we return and the
        # check is not left as an unused (prunable) op inside a tf.function graph.
        output = tf.ensure_shape(output, self.compute_output_shape(inputs.shape))
        return output

    def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape:
        """
        Computes the output shape of the layer.
        See `tf.keras.layers.Layer.compute_output_shape()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#compute_output_shape>`_.
        """
        # TODO: Keras docs say "If the layer has not been built, this method
        # will call `build` on the layer." -- do we need to do so?
        tensor_shape = tf.TensorShape(input_shape).with_rank(2)
        output_dim = self._compute_output_dim(input_shape)
        # Replace the trailing (feature) dimension D with the number of basis functions.
        trailing_shape = tensor_shape[:-1].concatenate(output_dim)
        if self.is_multioutput:
            return tf.TensorShape([self.num_latent_gps]).concatenate(trailing_shape)  # [P, N, M]
        else:
            return trailing_shape  # [N, M]

    def get_config(self) -> Mapping:
        """
        Returns the config of the layer.
        See `tf.keras.layers.Layer.get_config()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#get_config>`_.
        """
        config = super().get_config()
        config.update(
            {
                "kernel": self.kernel,
                "n_components": self.n_components,
                "input_dim": self._input_dim,
            }
        )

        return config

    @abstractmethod
    def _compute_output_dim(self, input_shape: ShapeType) -> int:
        """
        Return the number of output features ``M`` for inputs of shape ``input_shape``;
        used as the trailing dimension in :meth:`compute_output_shape`.
        """
        pass

    @abstractmethod
    def _compute_constant(self) -> tf.Tensor:
        """
        Compute normalizing constant for basis functions.
        """
        pass

    @abstractmethod
    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
        """
        Compute basis functions.
        """
        pass