markovflow.mean_function
Module containing mean functions.
MeanFunction
Bases: tf.Module, abc.ABC
Abstract class for mean functions.
Represents the action \(u(t)\) added to the latent states:
…resulting in the mean function:
We can then solve the pure SDE:
…where:
This class provides a very general interface for the function \(μ(t)\).
Note
Implementations of this class should typically avoid performing computation in their __init__ method. Performing computation in the constructor conflicts with running in TensorFlow’s eager mode.
__call__
Return the mean function evaluated at the given time points.
time_points – A tensor with shape [... num_data].
The mean function evaluated at the time points, with shape [... num_data, obs_dim].
ZeroMeanFunction
Bases: MeanFunction
Represents a mean function that is zero everywhere.
obs_dim – The dimension of the zeros to output.
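The shape contract described above (input `[... num_data]`, output `[... num_data, obs_dim]`) can be illustrated with a minimal NumPy sketch. The helper name `zero_mean` is hypothetical and stands in for the class; it is not the library's implementation.

```python
import numpy as np

def zero_mean(time_points, obs_dim):
    """Return zeros of shape [..., num_data, obs_dim] (illustrative helper only)."""
    time_points = np.asarray(time_points, dtype=float)
    # Keep every leading (batch) dimension and append the output dimension.
    return np.zeros(time_points.shape + (obs_dim,))

batch_times = np.linspace(0.0, 1.0, 5).reshape(1, 5)  # shape [1, 5]
print(zero_mean(batch_times, obs_dim=2).shape)  # (1, 5, 2)
```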
LinearMeanFunction
Represents a linear mean function, that is, \(m(t) = a t\), where \(a\) is the linear coefficient.
coefficient – The linear coefficient.
obs_dim – The output dimension of the mean function.
The mean function evaluated at the time points with shape [... num_data, obs_dim].
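As a rough illustration of \(m(t) = a t\) with the documented output shape, the following NumPy sketch assumes that the same scalar coefficient is applied to every output dimension. That assumption, and the helper name `linear_mean`, are not taken from the library.

```python
import numpy as np

def linear_mean(time_points, coefficient, obs_dim):
    """Evaluate a * t and repeat it across obs_dim outputs (illustrative helper only)."""
    time_points = np.asarray(time_points, dtype=float)
    # Add a trailing axis so the result can be broadcast to [..., num_data, obs_dim].
    values = coefficient * time_points[..., None]
    return np.broadcast_to(values, time_points.shape + (obs_dim,)).copy()

print(linear_mean([0.0, 1.0, 2.0], coefficient=3.0, obs_dim=2).shape)  # (3, 2)
```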
ImpulseMeanFunction
Represents a mean function that is an impulse action. This is given by:
…in:
…and then:
Or:
If we let:
…then we can write this as a LowerTriangularBlockTriDiagonal equation:
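\[
\begin{bmatrix}
I & & & & \\
-A_1 & I & & & \\
& -A_2 & I & & \\
& & \ddots & \ddots & \\
& & & -A_n & I
\end{bmatrix}
\begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ \vdots \\ a_n \end{bmatrix}
=
\begin{bmatrix} u_0 \\ u_1 \\ u_2 \\ \vdots \\ u_n \end{bmatrix}
\]
We can then determine the \(a_k\) using a matrix solve.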
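Because the system above has identity blocks on the diagonal and \(-A_k\) on the sub-diagonal, the solve reduces to forward substitution. The NumPy sketch below shows this and checks it against the equivalent dense matrix. The function names and the array layout (`sub_blocks` holding \(A_1, \dots, A_n\)) are assumptions made for illustration, not the library's API.

```python
import numpy as np

def solve_block_bidiagonal(sub_blocks, rhs):
    """Solve for a_0..a_n given A_1..A_n ([n, d, d]) and u_0..u_n ([n + 1, d])."""
    sub_blocks = np.asarray(sub_blocks, dtype=float)
    rhs = np.asarray(rhs, dtype=float)
    solution = np.empty_like(rhs)
    solution[0] = rhs[0]  # first block row: I a_0 = u_0
    for k in range(1, rhs.shape[0]):
        # block row k: -A_k a_{k-1} + a_k = u_k
        solution[k] = sub_blocks[k - 1] @ solution[k - 1] + rhs[k]
    return solution

def dense_matrix(sub_blocks):
    """Build the full lower block bidiagonal matrix, for checking only."""
    sub_blocks = np.asarray(sub_blocks, dtype=float)
    n, d, _ = sub_blocks.shape
    full = np.eye((n + 1) * d)
    for k in range(1, n + 1):
        full[k * d:(k + 1) * d, (k - 1) * d:k * d] = -sub_blocks[k - 1]
    return full

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 2, 2))  # A_1..A_3
u = rng.normal(size=(4, 2))     # u_0..u_3
a = solve_block_bidiagonal(A, u)
assert np.allclose(dense_matrix(A) @ a.reshape(-1), u.reshape(-1))
```

Forward substitution costs one matrix-vector product per block row, whereas forming and solving the dense matrix grows much faster with the number of actions.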
The effect of the action is not seen until fractionally after it is applied. That is, if an impulse is applied at time \(t\), \(μ(t)\) will not see the effect but \(μ(t + ε)\) will.
action_times – The times at which actions occur, with shape [... num_actions].
state_perturbations – The magnitude of the impulse, with shape [... num_actions, state_dim].
kernel – The kernel that is used to generate this mean function.
For each time point, we find the index of the function associated with it; that is, the closest previous impulse.
This index is then used to find the parameters of the function:
…where \(t_k < t \le t_{k+1}\).
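The index lookup described above can be sketched with a sorted search. This NumPy example assumes a one-dimensional, ascending `action_times` array with no batch dimensions; the helper name `previous_impulse_index` is hypothetical. Using `side="left"` means a time point that coincides with an action time is assigned to the previous impulse, which matches the convention that an impulse at time \(t\) is not yet visible at \(t\).

```python
import numpy as np

def previous_impulse_index(action_times, time_points):
    """Index k of the last action time strictly before each time point, or -1 if none."""
    action_times = np.asarray(action_times, dtype=float)  # assumed sorted ascending
    time_points = np.asarray(time_points, dtype=float)
    # side="left" counts the action times that are strictly less than t.
    return np.searchsorted(action_times, time_points, side="left") - 1

action_times = np.array([1.0, 2.0, 4.0])
indices = previous_impulse_index(action_times, [0.5, 1.0, 1.5, 2.0, 5.0])
print(indices.tolist())  # [-1, -1, 0, 0, 2]
```

An index of -1 marks time points that fall before the first impulse, so a caller would need to handle that case separately.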
The mean function evaluated at the time points, with shape [... num_data, state_dim].
StepMeanFunction
Represents a mean function that is a step action. This is given by:
Then:
…we can write this as a LowerTriangularBlockTriDiagonal equation:
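\[
\begin{bmatrix}
I & & & & \\
-A_1 & I & & & \\
& -A_2 & I & & \\
& & \ddots & \ddots & \\
& & & -A_n & I
\end{bmatrix}
\begin{bmatrix} b_0 \\ b_1 \\ b_2 \\ \vdots \\ b_n \end{bmatrix}
=
\begin{bmatrix} a_{-1} - a_0 \\ a_0 - a_1 \\ a_1 - a_2 \\ \vdots \\ a_{n-1} - a_n \end{bmatrix}
\]
We can then determine the \(b_k\) using a matrix solve.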
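Reading the block rows of this system one at a time gives an equivalent recursion, which is one way to see why the solve is cheap. This is a worked restatement of the system above, not a description of how the library performs the solve:

\[
b_0 = a_{-1} - a_0, \qquad b_k = A_k b_{k-1} + (a_{k-1} - a_k), \quad k = 1, \dots, n.
\]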
state_perturbations – The magnitude of the step, with shape [... num_actions, obs_dim].
For each time point, we find the index of the function associated with it; that is, the closest previous step (call it \(k\)).