[1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from gpflow.ci_utils import ci_niter
from gpflow.likelihoods import Gaussian
from markovflow.models import SparseVariationalGaussianProcess
from markovflow.kernels import Matern12, Matern32
from markovflow.kernels.sde_kernel import IndependentMultiOutputStack
from markovflow.ssm_natgrad import SSMNaturalGradient
This notebook is about stacked kernels, which are one way to get multiple outputs in MarkovFlow.
Stacked kernels use a leading ‘batch’ dimension to compute multiple kernels together. Conceptually, if a kernel matrix is of dimensions [N x N], then a stacked kernel produces an object of shape [S x N x N]. All of the markovflow computations will have this extra leading dimension. For example, the state-transition matrices will be of shape [S, T, D, D], where T is the number of time points and D is the state dimension.
The data, however, are expected to be of shape [N x S], so the S dimension should follow, not lead. This convention makes the stacked kernel compatible with likelihoods that can handle multiple outputs and processes.
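A minimal shape sketch in plain TensorFlow (not the markovflow API) may help fix the two conventions: kernel-side objects carry a leading S dimension, while the data keep S trailing. All names and values below are illustrative only.

import numpy as np
import tensorflow as tf

S, N = 2, 5
x = tf.constant(np.linspace(0, 1, N))                 # shared time points, [N]
lengthscales = tf.reshape(tf.constant([0.1, 0.5], dtype=x.dtype), (S, 1, 1))
sq_dist = tf.square(x[:, None] - x[None, :])          # [N, N]
K = tf.exp(-0.5 * sq_dist / tf.square(lengthscales))  # [S, N, N]: one Gram matrix per output
chol = tf.linalg.cholesky(K + 1e-6 * tf.eye(N, dtype=K.dtype))  # batched over the leading S
Y = tf.random.normal((N, S), dtype=K.dtype)           # data: the S dimension trails
print(K.shape, chol.shape, Y.shape)                   # (2, 5, 5) (2, 5, 5) (5, 2)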
The advantage of this approach to multiple outputs is computational efficiency: all computations loop over the leading S dimension instead of augmenting the state dimension of the process. However, using the same parameterization for the approximate posterior in an inference problem is somewhat restrictive, since it forces the posterior processes to be independent, which may not always be an appropriate assumption.
You may also be interested in another style of multiple-output kernel in MarkovFlow, where the state dimensions of the processes are concatenated. In that case, the computational complexity grows cubically with the number of outputs, since the state dimension grows with each output added; a rough comparison is sketched below. See the factor_analysis notebook.
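To make the trade-off concrete, here is a back-of-the-envelope comparison. This is a sketch only: it assumes Kalman-style recursions cost on the order of T · D³ in the state dimension D, and the values are illustrative, not taken from this notebook.

S, T, D = 4, 1000, 3
stacked = S * T * D ** 3           # loop over S independent D-dimensional processes
concatenated = T * (S * D) ** 3    # one process with concatenated state dimension S * D
print(stacked, concatenated)       # 108000 vs 1728000: a factor of S^2 = 16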
We’ll build a model with two outputs using a stacked kernel. We use the sparse GP object from markovflow to do inference.
[2]:
# constants
num_data = 300
num_inducing = 50
num_outputs = 2
lengthscales = [0.05, 0.05]
[3]:
# construct a simple data set with correlated noise
X = np.linspace(0, 1, num_data)
X_tf = tf.broadcast_to(X, (num_outputs, num_data))  # duplicate
F = np.hstack([np.sin(10 * X)[:, None], np.cos(15 * X)[:, None]])
Sigma = np.array([[0.1, 0.08], [0.08, 0.1]])
noise = np.random.multivariate_normal(np.zeros(2), Sigma, num_data)
Y = F + noise
data = (X_tf, tf.convert_to_tensor(Y))
[4]:
# construct a stacked kernel with two outputs
k1 = Matern12(lengthscale=lengthscales[0], variance=1.)
k2 = Matern32(lengthscale=lengthscales[1], variance=1.)
kern = IndependentMultiOutputStack([k1, k2], jitter=1e-6)

# construct a model
lik = Gaussian(1.)
Z = tf.broadcast_to(np.linspace(0, 1, num_inducing), (num_outputs, num_inducing))
m = SparseVariationalGaussianProcess(kern, lik, Z)
[5]:
def plot():
    # plot the data with predictions:
    p = m.posterior
    mu, var = p.predict_y(X_tf)
    for i in [0, 1]:
        plt.plot(X, Y[:, i], f'C{i}x', alpha=0.5)
        plt.plot(X, mu[:, i], f'C{i}')
        std = tf.math.sqrt(var[:, i])
        plt.plot(X, mu[:, i] + 2 * std, f'C{i}--')
        plt.plot(X, mu[:, i] - 2 * std, f'C{i}--')

plot()
_ = plt.title('The model fit before optimization')
[6]:
def optimize(model):
    # we'll use the natural gradient optimizer for the variational parameters
    # and the Adam optimizer for the hyperparameters
    opt_ng = SSMNaturalGradient(.5)
    opt_adam = tf.optimizers.Adam(0.05)

    @tf.function
    def step():
        opt_adam.minimize(
            lambda: -model.elbo(data),
            model._likelihood.trainable_variables + model._kernel.trainable_variables)
        opt_ng.minimize(lambda: -model.elbo(data), ssm=m.dist_q)

    @tf.function
    def elbo():
        return m.elbo(data)

    max_iter = ci_niter(400)
    for i in range(max_iter):
        step()
        if i % 50 == 0:
            print(f"Iteration {i}, elbo:{elbo().numpy():.4}")
[7]:
optimize(m)
Iteration 0, elbo:-994.6
Iteration 50, elbo:-522.1
Iteration 100, elbo:-311.5
Iteration 150, elbo:-253.8
Iteration 200, elbo:-239.8
Iteration 250, elbo:-235.0
Iteration 300, elbo:-231.9
Iteration 350, elbo:-229.1
[8]:
plot()
_ = plt.title('The model fit after optimization')
Since we’re using a leading dimension in the stacked kernel, we might worry that this impedes markovflow’s ability to fit a model to multiple independent datasets. Fear not! Extra leading dimensions are still handled (and looped over appropriately), and parameter sharing of the kernels (between datasets, not outputs) still happens smoothly.
In this example, we’ll fit a model with a heteroskedastic likelihood to multiple datasets simultaneously. The likelihood requires two GP outputs to model a single data column: one GP models the mean of the data, and the other models the (log) variance. We’ll generate multiple datasets, construct the outline of a very simple likelihood, and fit the whole shebang in a single markovflow model. Each dataset gets its own GPs, but the kernel parameters are shared amongst datasets.
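For reference, the expectation this likelihood needs has a standard closed form. Assuming independent Gaussian marginals $q(f_1) = \mathcal{N}(m_1, v_1)$ for the mean GP and $q(f_2) = \mathcal{N}(m_2, v_2)$ for the log-variance GP, and using $\mathbb{E}[e^{-f_2}] = e^{-m_2 + v_2/2}$,

$$\mathbb{E}_q\left[\log \mathcal{N}\left(y \mid f_1, e^{f_2}\right)\right] = -\tfrac{1}{2}\left(\log 2\pi + m_2 + e^{-m_2 + v_2/2}\left((y - m_1)^2 + v_1\right)\right),$$

which is exactly what variational_expectations computes in the likelihood class below.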
[9]:
num_data = 300
num_datasets = 2
num_inducing = 30
num_gp_outputs = 2
num_data_outputs = 1
lengthscales = [0.05, 0.5]
[10]:
# generate datasets from sinusoidal means and time-varying noise variances (exponentiated sinusoids)
Xs, Ys = [], []
for d in range(num_datasets):
    X = np.linspace(0, 1, num_data)
    amplitudes = np.random.rand(2) * np.array([1, 0.5]) + np.array([3, 2])
    phases = np.random.randn(2) * 2 * np.pi
    frequencies = np.array([10, 2])
    f1, f2 = [np.sin(2 * np.pi * X * omega + phi) * a
              for omega, phi, a in zip(frequencies, phases, amplitudes)]
    Y = f1 + np.random.randn(*f2.shape) * np.exp(0.5 * f2)
    Ys.append(Y.reshape(num_data, num_data_outputs))
    Xs.append(tf.broadcast_to(X, (num_gp_outputs, num_data)))

Xs = tf.convert_to_tensor(Xs)  # [num_datasets, num_gps, num_data]
Ys = tf.convert_to_tensor(Ys)  # [num_datasets, num_data, num_data_outputs]
data = (Xs, Ys)
[11]:
from markovflow.likelihoods import Likelihood
import tensorflow_probability as tfp


class HetGaussian(Likelihood):
    def log_probability_density(self, f, y):
        # first GP output is the mean, second is the log-variance
        mu, logvar = f[..., 0], f[..., 1]
        return tfp.distributions.Normal(loc=mu, scale=tf.exp(0.5 * logvar)).log_prob(y[..., 0])

    def variational_expectations(self, f_means, f_covariances, observations):
        f1, f2 = f_means[..., 0], f_means[..., 1]
        variances = f_covariances  # assume independent GPs
        v1, v2 = variances[..., 0], variances[..., 1]
        return -0.5 * (np.log(2 * np.pi) + f2
                       + tf.exp(-f2 + 0.5 * v2) * (tf.square(f1 - observations[..., 0]) + v1))

    def predict_density(self, f_means, f_covariances, observations):
        raise NotImplementedError

    def predict_mean_and_var(self, f_means, f_covariances):
        raise NotImplementedError
[12]:
# construct a stacked kernel with two outputs
k1 = Matern32(lengthscale=.05, variance=1.)
k2 = Matern12(lengthscale=.5, variance=1.)
kern = IndependentMultiOutputStack([k1, k2])

# construct a model
lik = HetGaussian()
Z = tf.broadcast_to(np.linspace(0, 1, num_inducing),
                    (num_datasets, num_gp_outputs, num_inducing))
m = SparseVariationalGaussianProcess(kern, lik, Z)
print(m.elbo(data))
tf.Tensor(-5570.847114316649, shape=(), dtype=float64)
[13]:
from gpflow.optimizers import Scipy

opt = Scipy()
opt.minimize(lambda: -m.elbo(data), m.trainable_variables,
             options=dict(maxiter=ci_niter(1000)))
      fun: 1290.3436786177656
 hess_inv: <1068x1068 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 1.42880514, -0.94413704,  0.89699324, ..., -1.10693675,
       -0.36922019,  0.29900939])
  message: 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
     nfev: 1049
      nit: 1000
     njev: 1049
   status: 1
  success: False
        x: array([-1.43898835e-01, -1.33500294e-03, -1.31227604e+01, ...,
        8.95741743e+00,  1.46095999e+00,  6.95882747e+00])
[14]:
mus, _ = m.posterior.predict_f(Xs)
f, axes = plt.subplots(num_datasets, 1, sharex=True, sharey=True, figsize=(8, 6))
for i, (Y, ax, mu) in enumerate(zip(Ys, axes, mus)):
    ax.plot(Y, 'C0.', alpha=0.3)
    ax.plot(mu[:, 0], 'C1')
    ax.plot(mu[:, 0] + 2 * tf.exp(0.5 * mu[:, 1]), 'C1--')
    ax.plot(mu[:, 0] - 2 * tf.exp(0.5 * mu[:, 1]), 'C1--')
    ax.set_title(f'dataset {i}')