Source code for gpflux.layers.basis_functions.fourier_features.quadrature.gaussian

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Kernel decomposition into features and coefficients based on Gauss-Christoffel
quadrature (also known as Gaussian quadrature).
"""

import warnings
from typing import Mapping, Tuple, Type

import tensorflow as tf

import gpflow
from gpflow.base import TensorType
from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights

from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
from gpflux.layers.basis_functions.fourier_features.utils import _bases_concat
from gpflux.types import ShapeType

#: Kernels supported by :class:`QuadratureFourierFeatures`.
#:
#: Currently we only support the :class:`gpflow.kernels.SquaredExponential` kernel.
#: For Matern kernels please use :class:`RandomFourierFeatures`
#: or :class:`RandomFourierFeaturesCosine`.
#:
#: Membership is tested with ``isinstance``, so subclasses of the listed kernels
#: are accepted as well.
# NOTE: ``#:`` comments (rather than a bare string placed before the assignment)
# are what Sphinx autodoc associates with a module-level data member.
QFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
    gpflow.kernels.SquaredExponential,
)


class QuadratureFourierFeatures(FourierFeaturesBase):
    """
    Fourier feature approximation of a kernel whose frequencies are fixed
    Gauss-Hermite quadrature nodes (rather than random samples).

    With ``n_components = M`` nodes per input dimension and ``D`` input dimensions,
    the layer uses ``M ** D`` quadrature nodes and produces ``2 * M ** D`` features.
    The feature count therefore grows exponentially with the input dimension.
    """

    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
        """
        :param kernel: The kernel to approximate; must be an instance of one of the
            classes in ``QFF_SUPPORTED_KERNELS``.
        :param n_components: The number of Gauss-Hermite quadrature nodes per input
            dimension (``M``).
        :param kwargs: Forwarded unchanged to :class:`FourierFeaturesBase`.
        :raises AssertionError: If ``kernel`` is not a supported kernel.
        """
        # Raise explicitly instead of using an ``assert`` statement: asserts are
        # stripped under ``python -O``, which would let unsupported kernels through
        # silently. AssertionError is kept so callers that catch it are unaffected.
        if not isinstance(kernel, QFF_SUPPORTED_KERNELS):
            raise AssertionError("Unsupported Kernel")
        if tf.reduce_any(tf.less(kernel.lengthscales, 1e-1)):
            warnings.warn(
                "Quadrature Fourier feature approximation of kernels "
                "with small lengthscale lead to unexpected behaviors!",
                # Attribute the warning to the code constructing the layer.
                stacklevel=2,
            )
        super().__init__(kernel, n_components, **kwargs)

    def build(self, input_shape: ShapeType) -> None:
        """
        Creates the variables of the layer.
        See `tf.keras.layers.Layer.build()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_.

        Both variables hold fixed quadrature data and are therefore non-trainable.
        """
        input_dim = input_shape[-1]
        abscissa_value, omegas_value = ndgh_points_and_weights(
            dim=input_dim, n_gh=self.n_components
        )
        # Weights come back with a trailing singleton axis; drop it.
        omegas_value = tf.squeeze(omegas_value, axis=-1)

        # Quadrature node points
        self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False)  # (M^D, D)
        # Gauss-Hermite weights
        self.factors = tf.Variable(initial_value=omegas_value, trainable=False)  # (M^D,)
        super().build(input_shape)

    def _compute_output_dim(self, input_shape: ShapeType) -> int:
        """
        Number of output features: two per quadrature node, i.e. ``2 * M ** D``
        where ``D`` is the last entry of ``input_shape``.
        """
        input_dim = input_shape[-1]
        return 2 * self.n_components ** input_dim

    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
        """
        Compute basis functions.

        :return: A tensor with the shape ``[N, 2M^D]``.
        """
        return _bases_concat(inputs, self.abscissa)

    def _compute_constant(self) -> tf.Tensor:
        """
        Compute normalizing constant for basis functions.

        Each quadrature node ``i`` is scaled by ``sqrt(variance * w_i)``. The vector is
        tiled twice so that both halves of the ``[N, 2M^D]`` bases (presumably the
        cosine and sine blocks built by ``_bases_concat``) get the same per-node scale.

        :return: A tensor with the shape ``[2M^D,]``
        """
        return tf.tile(tf.sqrt(self.kernel.variance * self.factors), multiples=[2])