Source code for gpflux.sampling.utils

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains utilities for sampling from multivariate Gaussian distributions.
"""
import tensorflow as tf

from gpflow.base import TensorType
from gpflow.conditionals.util import sample_mvn

from gpflux.math import _cholesky_with_jitter


[docs]def draw_conditional_sample(mean: TensorType, cov: TensorType, f_old: TensorType) -> tf.Tensor: r""" Draw a sample :math:`\tilde{f}_\text{new}` from the conditional multivariate Gaussian :math:`p(f_\text{new} | f_\text{old})`, where the parameters ``mean`` and ``cov`` are the mean and covariance matrix of the joint multivariate Gaussian over :math:`[f_\text{old}, f_\text{new}]`. :param mean: A tensor with the shape ``[..., D, N+M]`` with the mean of ``[f_old, f_new]``. For each ``[..., D]`` this is a stacked vector of the form: .. math:: \begin{pmatrix} \operatorname{mean}(f_\text{old}) \;[N] \\ \operatorname{mean}(f_\text{new}) \;[M] \end{pmatrix} :param cov: A tensor with the shape ``[..., D, N+M, N+M]`` with the covariance of ``[f_old, f_new]``. For each ``[..., D]``, there is a 2x2 block matrix of the form: .. math:: \begin{pmatrix} \operatorname{cov}(f_\text{old}, f_\text{old}) \;[N, N] & \operatorname{cov}(f_\text{old}, f_\text{new}) \;[N, M] \\ \operatorname{cov}(f_\text{new}, f_\text{old}) \;[M, N] & \operatorname{cov}(f_\text{new}, f_\text{new}) \;[M, M] \end{pmatrix} :param f_old: A tensor of observations with the shape ``[..., D, N]``, drawn from Normal distribution with mean :math:`\operatorname{mean}(f_\text{old}) \;[N]`, and covariance :math:`\operatorname{cov}(f_\text{old}, f_\text{old}) \;[N, N]` :return: A sample :math:`\tilde{f}_\text{new}` from the conditional normal :math:`p(f_\text{new} | f_\text{old})` with the shape ``[..., D, M]``. """ N, D = tf.shape(f_old)[-1], tf.shape(f_old)[-2] # noqa: F841 M = tf.shape(mean)[-1] - N cov_old = cov[..., :N, :N] # [..., D, N, N] cov_new = cov[..., -M:, -M:] # [..., D, M, M] cov_cross = cov[..., :N, -M:] # [..., D, N, M] L_old = _cholesky_with_jitter(cov_old) # [..., D, N, N] A = tf.linalg.triangular_solve(L_old, cov_cross, lower=True) # [..., D, N, M] var_new = cov_new - tf.matmul(A, A, transpose_a=True) # [..., D, M, M] mean_new = mean[..., -M:] # [..., D, M] mean_old = mean[..., :N] # [..., D, N] mean_old_diff = (f_old - mean_old)[..., None] # [..., D, N, 1] AM = tf.linalg.triangular_solve(L_old, mean_old_diff) # [..., D, N, 1] mean_new = mean_new + (tf.matmul(A, AM, transpose_a=True)[..., 0]) # [..., D, M] return sample_mvn(mean_new, var_new, full_cov=True)