markovflow.conditionals
Module for evaluating conditional distributions.
conditional_predict
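For all \(xₜ ∈\) new_time_points, compute the means and covariances of the marginal densities:
\[p(xₜ) = ∫ p(xₜ | x₋, x₊)\, 𝓝([x₋, x₊]; mₜ, Sₜ)\, d[x₋, x₊]\]
where \(mₜ\) and \(Sₜ\) are the mean and covariance of the pair \([x₋, x₊]\), taken from training_pairwise_means and training_pairwise_covariances.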
Or, if \(Sₜ\) is not given, compute the conditional density:
\[p(xₜ | [x₋, x₊] = mₜ)\]
Note
new_time_points and training_time_points must be sorted.
Where:
\(p\) is the density over state trajectories specified by the kernel.
\(∀ xₜ ∈\) new_time_points:
\[\begin{split}x₊ = \arg\min_x \{ |x-xₜ| : x ∈ \text{training\_time\_points},\ x>xₜ \}\\ x₋ = \arg\min_x \{ |x-xₜ| : x ∈ \text{training\_time\_points},\ x⩽xₜ \}\end{split}\]
Details of the computation of \(Pₜ\) and \(Tₜ\) are found in conditional_statistics().
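The neighbour definitions above reduce to a sorted search. Below is a minimal NumPy sketch. It assumes 1-D (unbatched) time points. `neighbouring_training_points` is a hypothetical helper written for illustration and is not part of markovflow's API.

```python
import numpy as np


def neighbouring_training_points(new_time_points, training_time_points):
    """Hypothetical helper: locate x₋ and x₊ for every xₜ in new_time_points.

    x₋ is the closest training point with x ⩽ xₜ and x₊ the closest with x > xₜ.
    A missing neighbour (xₜ before the first or after the last training point)
    is returned as NaN. Both inputs must be sorted 1-D arrays.
    """
    new = np.asarray(new_time_points, dtype=float)
    train = np.asarray(training_time_points, dtype=float)
    # side="right" places xₜ after any equal training point, so that
    # train[idx - 1] ⩽ xₜ < train[idx] whenever those entries exist.
    idx = np.searchsorted(train, new, side="right")
    # Pad with NaN at both ends so out-of-range neighbours become NaN.
    padded = np.concatenate([[np.nan], train, [np.nan]])
    x_minus = padded[idx]      # train[idx - 1], or NaN if idx == 0
    x_plus = padded[idx + 1]   # train[idx], or NaN if idx == len(train)
    return idx, x_minus, x_plus


idx, x_minus, x_plus = neighbouring_training_points([-0.5, 0.5, 1.0, 2.5], [0.0, 1.0, 2.0])
print(idx.tolist())      # [0, 1, 2, 3]
print(x_minus.tolist())  # [nan, 0.0, 1.0, 2.0]
print(x_plus.tolist())   # [0.0, 1.0, 2.0, nan]
```

The `side="right"` argument makes tie-breaking match the definitions. A new point that coincides with a training point takes that point as \(x₋\), because \(x₋\) uses \(x ⩽ xₜ\) while \(x₊\) uses a strict inequality.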
new_time_points – Sorted time points to generate observations for, with shape batch_shape + [num_new_time_points,].
training_time_points – Sorted time points to condition on, with shape batch_shape + [num_training_time_points,].
kernel – A kernel.
training_pairwise_means – Pairs of states to condition on, with shape batch_shape + [num_training_time_points, 2 * state_dim].
training_pairwise_covariances – Covariances of the pairs of states to condition on, with shape batch_shape + [num_training_time_points, 2 * state_dim, 2 * state_dim].
Predicted mean and covariance for the new time points, with respective shapes batch_shape + [num_new_time_points, state_dim] and batch_shape + [num_new_time_points, state_dim, state_dim].
conditional_statistics
For all \(xₜ ∈\) new_time_points, compute the statistics \(Pₜ\) and \(Tₜ\) of the conditional densities:
\[p(xₜ | x₋, x₊) = 𝓝(Pₜ @ [x₋, x₊], Tₜ)\]
…where:
\(p\) is the density over state trajectories specified by the kernel.
\(∀ xₜ ∈\) new_time_points:
\[\begin{split}x₊ = \arg\min_x \{ |x-xₜ| : x ∈ \text{training\_time\_points},\ x>xₜ \}\\ x₋ = \arg\min_x \{ |x-xₜ| : x ∈ \text{training\_time\_points},\ x⩽xₜ \}\end{split}\]
Parameters for the conditional mean and covariance, with respective shapes batch_shape + [num_new_time_points, state_dim, 2 * state_dim] and batch_shape + [num_new_time_points, state_dim, state_dim].
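Compare this form with \(p(xₜ | x₋, x₊) = 𝓝(xₜ; Dₜ @ x₋ + Eₜ @ x₊, Tₜ)\) in _conditional_statistics_from_transitions. The comparison suggests reading \(Pₜ\) as the block matrix \([Dₜ, Eₜ]\), which matches the stated shape batch_shape + [num_new_time_points, state_dim, 2 * state_dim]. Below is a NumPy sketch of that reading. The random placeholder matrices are an assumption for illustration; this is not markovflow's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
num_new, state_dim = 4, 3

# Placeholder projections: Dₜ acts on x₋ and Eₜ acts on x₊.
D = rng.standard_normal((num_new, state_dim, state_dim))
E = rng.standard_normal((num_new, state_dim, state_dim))

# Pₜ = [Dₜ, Eₜ], shape [num_new, state_dim, 2 * state_dim].
P = np.concatenate([D, E], axis=-1)

# Adjacent states [x₋, x₊], stacked along the last axis: [num_new, 2 * state_dim].
x_minus = rng.standard_normal((num_new, state_dim))
x_plus = rng.standard_normal((num_new, state_dim))
adjacent = np.concatenate([x_minus, x_plus], axis=-1)

# Conditional mean Pₜ @ [x₋, x₊], batched over the time axis t.
mean = np.einsum("tij,tj->ti", P, adjacent)

# Same result as Dₜ @ x₋ + Eₜ @ x₊.
split_mean = np.einsum("tij,tj->ti", D, x_minus) + np.einsum("tij,tj->ti", E, x_plus)
assert P.shape == (num_new, state_dim, 2 * state_dim)
assert np.allclose(mean, split_mean)
```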
_conditional_statistics_from_transitions
Given consecutive time differences Δ₁ = t - t₋ and Δ₂ = t₊ - t of ordered triplets t₋ < t < t₊, we denote the states at those times as x₋, xₜ, x₊ and their transition densities as:
    p(x₊ | xₜ) = 𝓝(x₊; Aₜ₊xₜ, Qₜ₊)   where [Aₜ₊ == A_tp, Qₜ₊ == Q_tp]
    p(xₜ | x₋) = 𝓝(xₜ; A₋ₜx₋, Q₋ₜ)   where [A₋ₜ == A_mt, Q₋ₜ == Q_mt]
This computes Dₜ, Eₜ, Tₜ (or Tₜ⁻¹) such that p(xₜ | x₋, x₊) = 𝓝(xₜ; Dₜ @ x₋ + Eₜ @ x₊, Tₜ)
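By the Markov property:
    p(x₊ | xₜ, x₋) = p(x₊ | xₜ) = 𝓝(Aₜ₊xₜ, Qₜ₊)
    p(xₜ | x₋) = 𝓝(A₋ₜx₋, Q₋ₜ)
    p(x₊ | x₋) = 𝓝(A₋₊x₋, Q₋₊),   with A₋₊ = Aₜ₊A₋ₜ and Q₋₊ = Qₜ₊ + Aₜ₊Q₋ₜAₜ₊ᵀ
so the joint density of [xₜ, x₊] given x₋ is:
    p([xₜ, x₊] | x₋) = 𝓝([A₋ₜx₋, A₋₊x₋], [[Q₋ₜ,     Q₋ₜAₜ₊ᵀ],
                                          [Aₜ₊Q₋ₜ,  Q₋₊    ]])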
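Given this joint distribution, Gaussian conditioning gives the mean and covariance of the conditional distribution:
    p(xₜ | [x₋, x₊]) = 𝓝(xₜ; A₋ₜx₋ + Q₋ₜAₜ₊ᵀQ₋₊⁻¹(x₊ - A₋₊x₋), Q₋ₜ - Q₋ₜAₜ₊ᵀQ₋₊⁻¹Aₜ₊Q₋ₜ)
                     = 𝓝(xₜ; (A₋ₜ - Q₋ₜAₜ₊ᵀQ₋₊⁻¹A₋₊)x₋ + Q₋ₜAₜ₊ᵀQ₋₊⁻¹x₊, (Q₋ₜ⁻¹ + Aₜ₊ᵀQₜ₊⁻¹Aₜ₊)⁻¹)
The covariance in the second line follows from the Woodbury identity. Hence Dₜ = A₋ₜ - Q₋ₜAₜ₊ᵀQ₋₊⁻¹A₋₊, Eₜ = Q₋ₜAₜ₊ᵀQ₋₊⁻¹ and Tₜ⁻¹ = Q₋ₜ⁻¹ + Aₜ₊ᵀQₜ₊⁻¹Aₜ₊.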
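The derivation above can be checked numerically. The sketch below uses NumPy, a single unbatched time point and random placeholder transitions; `conditional_from_transitions` is a hypothetical name. It computes \(Dₜ\), \(Eₜ\) and \(Tₜ\) from the closed forms and compares them with direct Gaussian conditioning of the joint \(p([xₜ, x₊] | x₋)\).

```python
import numpy as np


def conditional_from_transitions(A_mt, Q_mt, A_tp, Q_tp, return_precision=False):
    """Hypothetical sketch: Dₜ, Eₜ and Tₜ (or Tₜ⁻¹) for one time point."""
    A_mp = A_tp @ A_mt                    # A₋₊ = Aₜ₊A₋ₜ
    Q_mp = Q_tp + A_tp @ Q_mt @ A_tp.T    # Q₋₊ = Qₜ₊ + Aₜ₊Q₋ₜAₜ₊ᵀ
    # Eₜ = Q₋ₜAₜ₊ᵀQ₋₊⁻¹ via a linear solve; valid because Q₋ₜ and Q₋₊ are symmetric.
    E = np.linalg.solve(Q_mp, A_tp @ Q_mt).T
    D = A_mt - E @ A_mp                   # Dₜ = A₋ₜ - Q₋ₜAₜ₊ᵀQ₋₊⁻¹A₋₊
    # Information form of the conditional: Tₜ⁻¹ = Q₋ₜ⁻¹ + Aₜ₊ᵀQₜ₊⁻¹Aₜ₊
    precision = np.linalg.inv(Q_mt) + A_tp.T @ np.linalg.solve(Q_tp, A_tp)
    return D, E, (precision if return_precision else np.linalg.inv(precision))


rng = np.random.default_rng(1)
d = 2


def random_spd(n):
    """Random symmetric positive-definite matrix."""
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)


A_mt, A_tp = rng.standard_normal((d, d)), rng.standard_normal((d, d))
Q_mt, Q_tp = random_spd(d), random_spd(d)
D, E, T = conditional_from_transitions(A_mt, Q_mt, A_tp, Q_tp)

# Direct Gaussian conditioning of p([xₜ, x₊] | x₋) on x₊ (Schur complement).
A_mp = A_tp @ A_mt
Q_mp = Q_tp + A_tp @ Q_mt @ A_tp.T
cross = Q_mt @ A_tp.T                     # Cov(xₜ, x₊ | x₋)
x_minus, x_plus = rng.standard_normal(d), rng.standard_normal(d)
mean_direct = A_mt @ x_minus + cross @ np.linalg.solve(Q_mp, x_plus - A_mp @ x_minus)
cov_direct = Q_mt - cross @ np.linalg.solve(Q_mp, cross.T)

assert np.allclose(D @ x_minus + E @ x_plus, mean_direct)
assert np.allclose(T, cov_direct)
```

Building the precision first and inverting it is only one option. When Q₋ₜ may be singular (xₜ coinciding with x₋), the covariance form Q₋ₜ - Q₋ₜAₜ₊ᵀQ₋₊⁻¹Aₜ₊Q₋ₜ avoids inverting Q₋ₜ.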
state_transitions_to_t – The state transitions from t₋ to t, A₋ₜ, with shape batch_shape + [num_time_points, state_dim, state_dim].
process_covariances_to_t – The process covariances from t₋ to t, Q₋ₜ, with shape batch_shape + [num_time_points, state_dim, state_dim].
state_transitions_from_t – The state transitions from t to t₊, Aₜ₊, with shape batch_shape + [num_time_points, state_dim, state_dim].
process_covariances_from_t – The process covariances from t to t₊, Qₜ₊, with shape batch_shape + [num_time_points, state_dim, state_dim].
return_precision – Bool, defaults to False. If True, the conditional precision Tₜ⁻¹ is returned instead of the conditional covariance Tₜ.
Parameters Dₜ and Eₜ for the conditional mean, and Tₜ (or Tₜ⁻¹) for the conditional covariance (or precision), with respective shapes batch_shape + [num_time_points, state_dim, state_dim], batch_shape + [num_time_points, state_dim, state_dim] and batch_shape + [num_time_points, state_dim, state_dim].
_conditional_statistics
Compute the statistics Pₜ and Tₜ of the conditional densities p(xₜ|x₋, x₊) = 𝓝(Pₜ @ [x₋, x₊], Tₜ), where:
p is the density over state trajectories specified by the kernel
∀ xₜ ∈ new_time_points:
    x₊ = arg minₓ { |x-xₜ|, x ∈ training_time_points and x>xₜ }
    x₋ = arg minₓ { |x-xₜ|, x ∈ training_time_points and x⩽xₜ }
Warning: new_time_points and training_time_points must be sorted
new_time_points – Sorted time points to generate observations for, with shape batch_shape + [num_new_time_points,].
training_time_points – Sorted time points to condition on, with shape batch_shape + [num_training_time_points,].
kernel – A Markovian kernel.
Parameters for the conditional mean and covariance, and the insertion indices, with respective shapes batch_shape + [num_new_time_points, state_dim, 2*state_dim], batch_shape + [num_new_time_points, state_dim, state_dim] and batch_shape + [num_new_time_points,].
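The sketch below walks through the whole computation end to end. It uses a scalar Ornstein-Uhlenbeck (Matérn-1/2) kernel, an assumption made for illustration because its transitions are known in closed form: \(A(Δ) = e^{-Δ/ℓ}\) and \(Q(Δ) = σ²(1 - e^{-2Δ/ℓ})\). It finds the insertion indices, reads \(Pₜ\) as \([Dₜ, Eₜ]\), computes \(Tₜ\), and cross-checks both against ordinary GP conditioning on the two neighbours. The function names are hypothetical and no markovflow kernel API is used.

```python
import numpy as np


def ou_transition(delta, variance, lengthscale):
    """Exact scalar OU (Matérn-1/2) transition over a gap delta: A = e^(-Δ/ℓ), Q = σ²(1 - A²)."""
    a = np.exp(-delta / lengthscale)
    return a, variance * (1.0 - a**2)


def ou_conditional_statistics(new_time_points, training_time_points, variance=1.0, lengthscale=1.0):
    """Hypothetical scalar sketch returning Pₜ = [Dₜ, Eₜ], Tₜ and the insertion indices.

    Assumes training_time_points[0] ⩽ xₜ < training_time_points[-1] for every xₜ,
    so that both x₋ and x₊ exist.
    """
    new = np.asarray(new_time_points, dtype=float)
    train = np.asarray(training_time_points, dtype=float)
    idx = np.searchsorted(train, new, side="right")                      # insertion indices
    a1, q1 = ou_transition(new - train[idx - 1], variance, lengthscale)  # A₋ₜ, Q₋ₜ
    a2, q2 = ou_transition(train[idx] - new, variance, lengthscale)      # Aₜ₊, Qₜ₊
    q_mp = q2 + a2 * q1 * a2   # Q₋₊ = Qₜ₊ + Aₜ₊Q₋ₜAₜ₊ᵀ
    e = q1 * a2 / q_mp         # Eₜ = Q₋ₜAₜ₊ᵀQ₋₊⁻¹
    d = a1 - e * a2 * a1       # Dₜ = A₋ₜ - Eₜ A₋₊, with A₋₊ = Aₜ₊A₋ₜ
    t = q1 - e * a2 * q1       # Tₜ in covariance form, safe when Q₋ₜ = 0
    return np.stack([d, e], axis=-1), t, idx


variance, lengthscale = 2.0, 0.7
train = np.array([0.0, 1.0, 2.5])
new = np.array([0.0, 0.3, 1.7])
P, T, idx = ou_conditional_statistics(new, train, variance, lengthscale)


def k(s, u):
    """OU covariance σ² exp(-|s - u| / ℓ), used only for the cross-check."""
    return variance * np.exp(-np.abs(s - u) / lengthscale)


for i, t_new in enumerate(new):
    nb = train[[idx[i] - 1, idx[i]]]          # time points of [x₋, x₊]
    K = k(nb[:, None], nb[None, :])
    k_star = k(t_new, nb)
    assert np.allclose(P[i], np.linalg.solve(K, k_star))
    assert np.isclose(T[i], k(t_new, t_new) - k_star @ np.linalg.solve(K, k_star))
```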
cyclic_reduction_conditional_statistics
Compute \(Fₜ, Gₜ, Lₜ\) such that:
\[p(xᵉ | xᶜ) = 𝓝(xᵉ; -L⁻ᵀ Uᵀ xᶜ, (L Lᵀ)⁻¹)\]
…where superscripts \(e\) and \(c\) refer to explained and conditioning respectively.
\(xᵉ\) and \(xᶜ\) must be sorted, such that:
\[xᵉ₁ < xᶜ₁ < xᵉ₂ < xᶜ₂ < ⋯ < xᵉₙ < [xᶜₙ]\]
…where \(len(xᵉ) = len(xᶜ)\) or \(len(xᵉ) = len(xᶜ) + 1\).
This computes the conditional statistics \(Fₜ, Gₜ, Lₜ\), where \(𝔼 xᵉ|xᶜ = - L⁻ᵀ Uᵀ xᶜ\), with:
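\[Uᵀ = \begin{bmatrix} F₁ᵀ & & & & \\ G₁ᵀ & F₂ᵀ & & & \\ & G₂ᵀ & \ddots & & \\ & & \ddots & Fₙ₋₁ᵀ & \\ & & & Gₙ₋₁ᵀ & [Fₙᵀ] \end{bmatrix} \quad \text{and} \quad L⁻ᵀ = \begin{bmatrix} L₁⁻ᵀ & & & \\ & L₂⁻ᵀ & & \\ & & \ddots & \\ & & & Lₙ⁻ᵀ \end{bmatrix}\]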
\(L\) is the Cholesky factor of the conditional precision \(xᵉ|xᶜ\).
Statistics \(F\) and \(G\) are computed from the conditional mean projection parameters \(D\) and \(E\) defined by \(𝔼 xᵉₙ|xᶜ = Dₙ @ xᶜₙ₋₁ + Eₙ @ xᶜₙ\).
Solving the system \(- (L⁻ᵀ Uᵀ xᶜ)ₙ = Dₙ @ xᶜₙ₋₁ + Eₙ @ xᶜₙ\), we get \(Gₙ₋₁ᵀ = -Lₙᵀ Dₙ\) and \(Fₙᵀ = -Lₙᵀ Eₙ\).
Details of the system:
\[-\begin{bmatrix} L₁⁻ᵀ F₁ᵀ xᶜ₁ \\ L₂⁻ᵀ G₁ᵀ xᶜ₁ + L₂⁻ᵀ F₂ᵀ xᶜ₂ \\ L₃⁻ᵀ G₂ᵀ xᶜ₂ + L₃⁻ᵀ F₃ᵀ xᶜ₃ \\ \vdots \\ Lₙ⁻ᵀ Gₙ₋₁ᵀ xᶜₙ₋₁ + Lₙ⁻ᵀ [Fₙᵀ] xᶜₙ \end{bmatrix} = \begin{bmatrix} E₁ xᶜ₁ \\ D₂ xᶜ₁ + E₂ xᶜ₂ \\ D₃ xᶜ₂ + E₃ xᶜ₃ \\ \vdots \\ Dₙ xᶜₙ₋₁ + [Eₙ] xᶜₙ \end{bmatrix}\]
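The system can be solved block by block. The NumPy sketch below assumes unbatched inputs and that \(D\), \(E\), \(T\) are already available (random placeholders here); `cyclic_reduction_stats` is a hypothetical name, not markovflow's function. It forms \(Lₙ\) as the lower Cholesky factor of \(Tₙ⁻¹\) and applies \(Fₙᵀ = -LₙᵀEₙ\) and \(Gₙ₋₁ᵀ = -LₙᵀDₙ\). It then assembles dense \(Uᵀ\) and \(L⁻ᵀ\) to confirm that \(-L⁻ᵀUᵀxᶜ\) reproduces \(Dₙ xᶜₙ₋₁ + Eₙ xᶜₙ\).

```python
import numpy as np


def cyclic_reduction_stats(D, E, T):
    """Hypothetical sketch of F, G, L from per-point D, E, T (unbatched).

    D: [n_e - 1, d, d]  Dₙ for explained points n = 2..n_e (acts on xᶜₙ₋₁)
    E: [n_c, d, d]      Eₙ for explained points n = 1..n_c (acts on xᶜₙ)
    T: [n_e, d, d]      conditional covariances Tₙ
    """
    L = np.linalg.cholesky(np.linalg.inv(T))   # Lₙ Lₙᵀ = Tₙ⁻¹ (conditional precision)
    Lt = np.swapaxes(L, -1, -2)
    n_c = E.shape[0]
    F = -np.swapaxes(Lt[:n_c] @ E, -1, -2)     # Fₙᵀ = -Lₙᵀ Eₙ
    G = -np.swapaxes(Lt[1:] @ D, -1, -2)       # Gₙ₋₁ᵀ = -Lₙᵀ Dₙ
    return F, G, L


rng = np.random.default_rng(2)
d, n_e, n_c = 2, 3, 2                           # case len(xᵉ) = len(xᶜ) + 1
D = rng.standard_normal((n_e - 1, d, d))
E = rng.standard_normal((n_c, d, d))
M = rng.standard_normal((n_e, d, d))
T = M @ np.swapaxes(M, -1, -2) + d * np.eye(d)  # random SPD covariances
F, G, L = cyclic_reduction_stats(D, E, T)

# Assemble dense Uᵀ (block lower-bidiagonal) and L⁻ᵀ (block diagonal).
Ut = np.zeros((n_e * d, n_c * d))
Linv_T = np.zeros((n_e * d, n_e * d))
for n in range(n_e):
    rows = slice(n * d, (n + 1) * d)
    Linv_T[rows, rows] = np.linalg.inv(L[n]).T
    if n < n_c:
        Ut[rows, n * d:(n + 1) * d] = F[n].T          # Fᵀ block on the diagonal
    if n > 0:
        Ut[rows, (n - 1) * d:n * d] = G[n - 1].T      # Gᵀ block just below it

x_c = rng.standard_normal((n_c, d))
mean = -Linv_T @ Ut @ x_c.reshape(-1)
expected = np.concatenate([
    (E[n] @ x_c[n] if n < n_c else np.zeros(d))
    + (D[n - 1] @ x_c[n - 1] if n > 0 else np.zeros(d))
    for n in range(n_e)
])
assert np.allclose(mean, expected)
```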
Remarks on size:
When splitting \(x\) of size \(n\) into odd and even, you get \(nᵉ = (n+1)//2\) and \(nᶜ = n//2\)
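At each level, cyclic reduction statistics have shape:
\(F\): batch_shape + [nᶜ, state_dim, state_dim]
\(G\): batch_shape + [nᵉ - 1, state_dim, state_dim]
\(L\): batch_shape + [nᵉ, state_dim, state_dim]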
Note that:
\(G₀\) is not defined (there is no conditioning time point below \(xᵉ₁\)), so \(G\) has one element fewer than \(L\).
\(Fₙ\) is not defined if \(len(xᵉ) = len(xᶜ) + 1\) (there is no conditioning time point above \(xᵉₙ\)).
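The size remarks can be illustrated with a plain even/odd split. The sketch assumes positions are counted from zero, so the explained points take the even positions and therefore the larger half. The time values and state_dim are arbitrary.

```python
import numpy as np

time_points = np.array([0.1, 0.4, 0.9, 1.3, 2.0])  # sorted, n = 5
n = len(time_points)

# Assumed convention: explained points at even positions (counting from 0),
# conditioning points at odd ones, giving xᵉ₁ < xᶜ₁ < xᵉ₂ < xᶜ₂ < xᵉ₃.
explained = time_points[0::2]
conditioning = time_points[1::2]
assert len(explained) == (n + 1) // 2 and len(conditioning) == n // 2
print(explained.tolist(), conditioning.tolist())  # [0.1, 0.9, 2.0] [0.4, 1.3]

# Shapes of the statistics at this level, for an arbitrary state_dim = 2.
state_dim = 2
F_shape = (len(conditioning), state_dim, state_dim)   # (2, 2, 2)
G_shape = (len(explained) - 1, state_dim, state_dim)  # (2, 2, 2)
L_shape = (len(explained), state_dim, state_dim)      # (3, 2, 2)
```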
explained_time_points – Sorted time points to generate observations for, with shape batch_shape + [num_time_points_1,].
conditioning_time_points – Sorted time points to condition on, with shape batch_shape + [num_time_points_2,].
Parameters \(F\), \(G\) and \(L\) for the conditional, with respective shapes batch_shape + [num_conditioning, state_dim, state_dim], batch_shape + [num_explained - 1, state_dim, state_dim] and batch_shape + [num_explained, state_dim, state_dim].
base_conditional_predict
Predict state at new time points given conditional statistics.
Given conditional statistics \(Pₜ, Tₜ\) of \(p(xₜ|x₋, x₊) = 𝓝(Pₜ @ [x₋, x₊], Tₜ)\) and pairwise marginals \(p(x₋, x₊) = 𝓝(mₜ, Sₜ)\), compute the mean and covariance of the marginal density:
\[p(xₜ) = ∫ p(xₜ | x₋, x₊)\, 𝓝([x₋, x₊]; mₜ, Sₜ)\, d[x₋, x₊] = 𝓝(xₜ; Pₜ mₜ, Tₜ + Pₜ Sₜ Pₜᵀ)\]
If \(Sₜ\) is not provided, compute the mean and covariance of the conditional density:
\[p(xₜ | [x₋, x₊] = mₜ) = 𝓝(xₜ; Pₜ mₜ, Tₜ)\]
conditional_projections – \(Pₜ\) with shape batch_shape + [num_time_points, state_dim, 2 * state_dim].
conditional_covariances – \(Tₜ\) with shape batch_shape + [num_time_points, state_dim, state_dim].
adjacent_states – Pairs of states to condition on, with shape batch_shape + [num_time_points, 2 * state_dim].
pairwise_state_covariances – Covariances of the pairs of states to condition on, with shape batch_shape + [num_time_points, 2 * state_dim, 2 * state_dim].
Predicted mean and covariance for the time points, with respective shapes batch_shape + [num_time_points, state_dim] and batch_shape + [num_time_points, state_dim, state_dim].
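Both cases reduce to a few batched matrix products. The marginal is Gaussian with mean \(Pₜ mₜ\) and covariance \(Tₜ + Pₜ Sₜ Pₜᵀ\) (law of total covariance); without \(Sₜ\), only the conditional covariance \(Tₜ\) remains. Below is a NumPy sketch under those assumptions, with inputs unbatched apart from the time axis. `base_conditional_predict_sketch` is a hypothetical name.

```python
import numpy as np


def base_conditional_predict_sketch(P, T, m, S=None):
    """Hypothetical sketch of the marginal (or conditional) prediction.

    P: [num_time_points, state_dim, 2 * state_dim]               Pₜ
    T: [num_time_points, state_dim, state_dim]                   Tₜ
    m: [num_time_points, 2 * state_dim]                          mₜ
    S: [num_time_points, 2 * state_dim, 2 * state_dim] or None   Sₜ
    """
    mean = np.einsum("tij,tj->ti", P, m)          # Pₜ mₜ
    if S is None:
        return mean, T                            # conditional: 𝓝(Pₜ mₜ, Tₜ)
    cov = T + P @ S @ np.swapaxes(P, -1, -2)      # marginal: Tₜ + Pₜ Sₜ Pₜᵀ
    return mean, cov


rng = np.random.default_rng(3)
num_time_points, state_dim = 4, 2
P = rng.standard_normal((num_time_points, state_dim, 2 * state_dim))
B = rng.standard_normal((num_time_points, state_dim, state_dim))
T = B @ np.swapaxes(B, -1, -2)                    # placeholder covariance Tₜ
m = rng.standard_normal((num_time_points, 2 * state_dim))
C = rng.standard_normal((num_time_points, 2 * state_dim, 2 * state_dim))
S = C @ np.swapaxes(C, -1, -2)                    # placeholder covariance Sₜ

mean, cov = base_conditional_predict_sketch(P, T, m, S)
cond_mean, cond_cov = base_conditional_predict_sketch(P, T, m)
```

Passing `S=None` mirrors the case where \(Sₜ\) is not provided. The mean is unchanged, and the covariance reduces to \(Tₜ\) because the neighbouring states are treated as fixed at \(mₜ\).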
pairwise_marginals
TODO(sam): figure out what the initial mean and covariance should be for non-stationary kernels
For each pair of subsequent states \(xₖ, xₖ₊₁\), return the mean and covariance of their joint distribution. This is assuming we start from, and revert to, the prior:
Then:
dist – The distribution.
initial_mean – The prior mean (used to extend the pairwise marginals of the distribution).
initial_covariance – The prior covariance (used to extend the pairwise marginals of the state space model).
Mean and covariance pairs for the marginals, with respective shapes batch_shape + [num_transitions + 2, state_dim] and batch_shape + [num_transitions + 2, state_dim, state_dim].