markovflow.models.pep
Module containing a model for Power Expectation Propagation (PEP).
PowerExpectationPropagation
Bases: markovflow.models.variational_cvi.GaussianProcessWithSitesBase
This is an approximate inference scheme called Power Expectation Propagation (PEP).
It approximates the posterior of a model with a GP prior and a general likelihood by a Gaussian posterior parameterized with Gaussian sites.
The following notation is used:
x - the time points of the training data.
y - observations corresponding to time points x.
s(.) - the latent state of the Markov chain
f(.) - the noise free predictions of the model
p(y | f) - the likelihood
t(f) - a site (indices will refer to the associated data point)
p(.) - the prior distribution
q(.) - the variational distribution
We use the state space formulation of Markovian Gaussian Processes, which specifies:
the conditional density of neighbouring latent states: p(sₖ₊₁ | sₖ)
how to read out the latent process from these states: fₖ = H sₖ
The likelihood p(yₖ | fₖ) links the data to the latent process. We would like to approximate the posterior over the latent states of this state space model.
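As an illustration of the state space formulation, the sketch below assumes a Matérn-1/2 (Ornstein-Uhlenbeck) prior, for which the state is one-dimensional, p(sₖ₊₁ | sₖ) = N(A sₖ, Q) with A = exp(-Δₖ/ℓ) and Q = σ²(1 - A²), and the read-out is H = 1. The kernel choice and the helper names (matern12_transition, sample_prior, propagate_variance) are hypothetical, and the code is plain Python rather than the markovflow TensorFlow API.

```python
import math
import random


def matern12_transition(dt, lengthscale, variance):
    """Return (A, Q) so that p(s_{k+1} | s_k) = N(A * s_k, Q) for a Matern-1/2 prior."""
    a = math.exp(-dt / lengthscale)  # state transition (a 1x1 "matrix")
    q = variance * (1.0 - a * a)     # process noise variance
    return a, q


H = 1.0  # read-out f_k = H * s_k; for Matern-1/2 the state is the function value itself


def sample_prior(times, lengthscale, variance, seed=0):
    """Draw f at sorted time points by running the Markov chain forward."""
    rng = random.Random(seed)
    s = rng.gauss(0.0, math.sqrt(variance))  # start from the stationary distribution
    fs = [H * s]
    for t_prev, t_next in zip(times[:-1], times[1:]):
        a, q = matern12_transition(t_next - t_prev, lengthscale, variance)
        s = a * s + rng.gauss(0.0, math.sqrt(q))
        fs.append(H * s)
    return fs


def propagate_variance(times, lengthscale, variance):
    """Marginal variance of s_k obtained by chaining the transitions."""
    p = variance
    out = [p]
    for t_prev, t_next in zip(times[:-1], times[1:]):
        a, q = matern12_transition(t_next - t_prev, lengthscale, variance)
        p = a * p * a + q
        out.append(p)
    return out
```

Because A²σ² + Q = σ², chaining the transitions keeps the marginal variance at σ² at every time point, so the chain reproduces the stationary prior.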
We parameterize the approximate joint using sites tₖ(fₖ), which take the place of the likelihood terms p(yₖ | fₖ):
p(s, y) ≈ p(s) ∏ₖ tₖ(fₖ)
where the tₖ(fₖ) are univariate Gaussian sites parameterized in natural form:
t(f) = exp(ηᵀφ(f) - A(η)), where η = [η₁, η₂] and φ(f) = [f, f²]
(note: the subscript k has been omitted for simplicity)
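A minimal sketch of this natural parameterization, assuming A(η) is chosen so that t(f) integrates to one (EP sites are in general unnormalized, so the normalization is only for illustration). For φ(f) = [f, f²], a Gaussian with mean m and variance v has η₁ = m/v and η₂ = -1/(2v). The function names are hypothetical.

```python
import math


def moments_to_natural(mean, var):
    """phi(f) = [f, f^2]  =>  eta1 = m / v, eta2 = -1 / (2 v)."""
    return mean / var, -0.5 / var


def natural_to_moments(eta1, eta2):
    """Inverse map; requires eta2 < 0 (positive variance)."""
    var = -0.5 / eta2
    return eta1 * var, var


def log_partition(eta1, eta2):
    """A(eta) that makes exp(eta1 f + eta2 f^2 - A(eta)) a normalized density."""
    return -eta1 ** 2 / (4.0 * eta2) + 0.5 * math.log(-math.pi / eta2)


def site_density(f, eta1, eta2):
    """Evaluate t(f) = exp(eta^T phi(f) - A(eta))."""
    return math.exp(eta1 * f + eta2 * f * f - log_partition(eta1, eta2))
```

With this choice of A(η), site_density agrees with the usual Gaussian density N(f; m, v).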
The sites are updated with the classic EP update rules, as described in:
Seeger, Matthias. Expectation propagation for exponential families. 2005.
kernel – A kernel that defines a prior over functions.
input_data – A tuple of (time_points, observations) containing the observed data: time points of observations, with shape batch_shape + [num_data], observations with shape batch_shape + [num_data, observation_dim].
likelihood – A likelihood p(y | f) for the observations.
mean_function – The mean function for the GP. Defaults to no mean function.
learning_rate – The learning rate (step size) used when updating the sites.
alpha – The power α in Power Expectation Propagation; setting α = 1 recovers standard EP.
local_objective
Local objective of the PEP algorithm: log E_q(f)[p(y | f)^α]
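For a Gaussian likelihood p(y | f) = N(y; f, σ²) the local objective has a closed form, because p(y | f)^α is proportional to N(y; f, σ²/α). The sketch below assumes that likelihood and plain Python in place of markovflow likelihood objects, and compares the closed form with a direct trapezoidal integration over f ~ N(fmu, fvar); all function names are hypothetical.

```python
import math


def gaussian_logpdf(x, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (x - mean) ** 2 / var


def local_objective_closed_form(y, fmu, fvar, noise_var, alpha):
    """log E_{N(f; fmu, fvar)}[N(y; f, noise_var)^alpha] in closed form."""
    # N(y; f, s2)^alpha = (2 pi s2)^((1 - alpha) / 2) * alpha^(-1/2) * N(y; f, s2 / alpha)
    log_const = 0.5 * (1.0 - alpha) * math.log(2.0 * math.pi * noise_var) - 0.5 * math.log(alpha)
    return log_const + gaussian_logpdf(y, fmu, fvar + noise_var / alpha)


def local_objective_quadrature(y, fmu, fvar, noise_var, alpha, n=4001, width=12.0):
    """Same quantity by the trapezoidal rule over fmu +/- width standard deviations."""
    sd = math.sqrt(fvar)
    h = 2.0 * width * sd / (n - 1)
    total = 0.0
    for i in range(n):
        f = fmu - width * sd + i * h
        w = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        total += w * math.exp(gaussian_logpdf(f, fmu, fvar) + alpha * gaussian_logpdf(y, f, noise_var))
    return math.log(total * h)
```

Setting alpha = 1 reduces the closed form to the familiar log N(y; fmu, fvar + σ²).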
local_objective_gradients
Gradients of the local objective of the PEP algorithm with respect to the predictive mean
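Still assuming a Gaussian likelihood N(y; f, σ²), the gradients of the closed-form local objective with respect to the predictive mean (and, for completeness, the predictive variance) can be written down directly and checked by finite differences. This is a hypothetical plain-Python sketch, not the markovflow implementation, which operates on TensorFlow tensors.

```python
import math


def gaussian_local_objective(y, fmu, fvar, noise_var, alpha):
    """Closed-form log E_{N(f; fmu, fvar)}[N(y; f, noise_var)^alpha]."""
    s = fvar + noise_var / alpha
    log_const = 0.5 * (1.0 - alpha) * math.log(2.0 * math.pi * noise_var) - 0.5 * math.log(alpha)
    return log_const - 0.5 * math.log(2.0 * math.pi * s) - 0.5 * (y - fmu) ** 2 / s


def gaussian_local_objective_grads(y, fmu, fvar, noise_var, alpha):
    """Derivatives of the local objective w.r.t. the predictive mean and variance."""
    s = fvar + noise_var / alpha
    d_mean = (y - fmu) / s
    d_var = -0.5 / s + 0.5 * (y - fmu) ** 2 / s ** 2
    return d_mean, d_var
```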
mask_indices
Binary mask (cast to float), 0 for the excluded indices, 1 for the rest
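A plain-Python sketch of such a mask; the function name make_mask and its exclude_indices/num_data arguments are hypothetical.

```python
def make_mask(exclude_indices, num_data):
    """Float mask of length num_data: 0.0 at excluded indices, 1.0 elsewhere."""
    excluded = set(exclude_indices)
    return [0.0 if n in excluded else 1.0 for n in range(num_data)]
```

Multiplying per-data-point terms elementwise by this mask zeroes out the contributions of the excluded points while leaving the rest unchanged.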
compute_cavity_from_marginals
Compute the cavity distributions from the marginals.
marginals – A list of tensors.
compute_cavity
The cavity distributions for all data points. This corresponds to the marginal distribution q^{\n}(fₙ) of q^{\n}(f) ∝ q(f) / tₙ(fₙ)^α, where α is the PEP power.
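Because the marginal q(fₙ) and the site tₙ(fₙ) are both Gaussian, dividing by tₙ(fₙ)^α amounts to subtracting α times the site natural parameters from those of the marginal (with η₁ = m/v, η₂ = -1/(2v)). A minimal sketch under that parameterization, with a hypothetical function name; it raises an error when the cavity variance would be non-positive, a known failure mode of EP.

```python
def cavity_from_marginal(marg_mean, marg_var, site_eta1, site_eta2, alpha):
    """Cavity mean and variance: q_cav(f_n) proportional to q(f_n) / t_n(f_n)**alpha."""
    # natural parameters of the Gaussian marginal q(f_n) = N(marg_mean, marg_var)
    q_eta1, q_eta2 = marg_mean / marg_var, -0.5 / marg_var
    # dividing by t_n(f_n)**alpha subtracts alpha times the site natural parameters
    cav_eta1 = q_eta1 - alpha * site_eta1
    cav_eta2 = q_eta2 - alpha * site_eta2
    if cav_eta2 >= 0.0:
        raise ValueError("cavity variance is not positive")
    cav_var = -0.5 / cav_eta2
    return cav_eta1 * cav_var, cav_var
```

As a worked example: with prior N(0, 1) and a site with natural parameters (1.0, -0.25), the marginal is N(2/3, 2/3); removing the full site (α = 1) recovers the prior N(0, 1), while α = 0.5 gives the cavity N(0.4, 0.8).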
compute_log_norm
Compute log normalizer
update_sites
Compute the site updates and perform one update step.
site_indices – A list of the indices of the sites to be updated.
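One damped power-EP update of a single site can be sketched as follows: form the cavity, moment-match the tilted distribution cavity × p(y | f)^α, choose the new site so that t_new^α equals the matched Gaussian divided by the cavity, and then move a fraction learning_rate of the way towards it in natural parameters. The quadrature-based moment matching, this particular damping form and all function names are assumptions for illustration, not the markovflow code.

```python
import math


def tilted_moments(log_lik, cav_mean, cav_var, alpha, n=4001, width=12.0):
    """Mean and variance of the tilted density ∝ N(f; cav_mean, cav_var) * p(y | f)**alpha."""
    sd = math.sqrt(cav_var)
    h = 2.0 * width * sd / (n - 1)
    z = m1 = m2 = 0.0
    for i in range(n):
        f = cav_mean - width * sd + i * h
        w = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        # unnormalized weight: constant factors cancel in the moment ratios
        p = w * math.exp(-0.5 * (f - cav_mean) ** 2 / cav_var + alpha * log_lik(f))
        z += p
        m1 += p * f
        m2 += p * f * f
    mean = m1 / z
    return mean, m2 / z - mean ** 2


def damped_site_update(site, marg_mean, marg_var, log_lik, alpha, learning_rate):
    """Return updated site natural parameters (eta1, eta2) for one data point."""
    eta1, eta2 = site
    # cavity: q(f_n) / t_n(f_n)**alpha in natural parameters
    cav_eta1 = marg_mean / marg_var - alpha * eta1
    cav_eta2 = -0.5 / marg_var - alpha * eta2
    cav_var = -0.5 / cav_eta2
    cav_mean = cav_eta1 * cav_var
    # moment-match (project) the tilted distribution onto a Gaussian
    t_mean, t_var = tilted_moments(log_lik, cav_mean, cav_var, alpha)
    # new site satisfies t_new**alpha = projected tilted / cavity
    new_eta1 = (t_mean / t_var - cav_eta1) / alpha
    new_eta2 = (-0.5 / t_var - cav_eta2) / alpha
    # damped step in natural parameters
    lr = learning_rate
    return (1.0 - lr) * eta1 + lr * new_eta1, (1.0 - lr) * eta2 + lr * new_eta2
```

For a Gaussian likelihood N(y; f, σ²) the projection is exact, so an undamped update (learning_rate = 1) returns the likelihood's own natural parameters (y/σ², -1/(2σ²)) for any α, which makes a convenient sanity check.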
elbo
Computes the log marginal likelihood of the approximate joint p(s, y)
energy
PEP Energy
predict_log_density
Compute the log density of the data at the new data points.
input_data – A tuple of time points and observations containing the data at which to calculate the loss for training the model: a tensor of inputs with shape batch_shape + [num_data], a tensor of observations with shape batch_shape + [num_data, observation_dim].
full_output_cov – Either full output covariance (True) or marginal variances (False).
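For a Gaussian likelihood, the per-point predictive log density integrates the latent predictive N(f; μ, v) against N(y; f, σ²), giving log N(y; μ, v + σ²). A minimal sketch assuming marginal (not full) output covariances and hypothetical argument names:

```python
import math


def gaussian_log_predictive(ys, pred_means, pred_vars, noise_var):
    """Per-point log N(y; mean, var + noise_var) for a Gaussian likelihood."""
    out = []
    for y, m, v in zip(ys, pred_means, pred_vars):
        total_var = v + noise_var  # latent predictive variance plus observation noise
        out.append(-0.5 * math.log(2.0 * math.pi * total_var) - 0.5 * (y - m) ** 2 / total_var)
    return out
```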
gradient_correction
Transforms vectors g=[g1,g2] and i=[i1,i2] into h=[h1, h2] where h2 = 1/2 * 1/(i2 + 1/g2) and h1 = 2 * h2 * (g1/g2 - i1)
inputs – A tensor of inputs with shape batch_shape + [num_data].
grads – A tensor of gradients with shape batch_shape + [num_data].
Returns: a tensor of modified gradients with shape batch_shape + [num_data].
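The transformation can be sketched elementwise in plain Python, representing g and i as pairs of equal-length lists; this data layout and the function name correct_gradients are assumptions for illustration, following the formulas for h₁ and h₂ given above.

```python
def correct_gradients(inputs, grads):
    """Map g = [g1, g2] and i = [i1, i2] to h = [h1, h2], elementwise over data points."""
    i1, i2 = inputs
    g1, g2 = grads
    # h2 = 1/2 * 1/(i2 + 1/g2)
    h2 = [0.5 / (i2_n + 1.0 / g2_n) for i2_n, g2_n in zip(i2, g2)]
    # h1 = 2 * h2 * (g1/g2 - i1)
    h1 = [2.0 * h2_n * (g1_n / g2_n - i1_n) for h2_n, g1_n, g2_n, i1_n in zip(h2, g1, g2, i1)]
    return h1, h2
```

For example, i = [0.5, 0.25] and g = [3.0, 2.0] give h₂ = 1/(2(0.25 + 0.5)) = 2/3 and h₁ = 2 · (2/3) · (1.5 - 0.5) = 4/3.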