A state space model (SSM) posits a set of latent (unobserved) variables that evolve over time with dynamics specified by a probabilistic transition model p(z[t+1] | z[t]). At each timestep, we observe a value sampled from an observation model conditioned on the current state, p(x[t] | z[t]). The special case where both the transition and observation models are Gaussian, with means specified as linear functions of the inputs, is known as a linear Gaussian state space model; it supports tractable exact probabilistic calculations. See tfd_linear_gaussian_state_space_model for details.
Usage
sts_additive_state_space_model(
  component_ssms,
  constant_offset = 0,
  observation_noise_scale = NULL,
  initial_state_prior = NULL,
  initial_step = 0,
  validate_args = FALSE,
  allow_nan_stats = TRUE,
  name = NULL
)
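Arguments
| Argument | Description |
|---|---|
| component_ssms | List of one or more LinearGaussianStateSpaceModel instances, one per component. Components may have different latent sizes. |
| constant_offset | scalar constant value added to the sum of the outputs of the component models. Default value: 0. |
| observation_noise_scale | Optional scalar standard deviation of the observation noise. If specified, the observation noise scales of the component models are ignored. If NULL (the default), it is derived from the component models' observation noise scales, as described under Mathematical Details. |
| initial_state_prior | instance of a multivariate normal distribution representing the prior on the concatenated latent state at time initial_step. If NULL (the default), it is built by concatenating the priors of the component models. |
| initial_step | Optional scalar integer specifying the starting timestep. Default value: 0. |
| validate_args | Logical. Whether to validate input with asserts. If FALSE and the inputs are invalid, correct behavior is not guaranteed. Default value: FALSE. |
| allow_nan_stats | Logical. If FALSE, raise an exception if a statistic (e.g. mean, mode) is undefined for any batch member. If TRUE, batch members with valid parameters leading to undefined statistics return NaN for that statistic. Default value: TRUE. |
| name | string prefixed to ops created by this class. Default value: "AdditiveStateSpaceModel". |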
Value
an instance of LinearGaussianStateSpaceModel.
Details
The sts_additive_state_space_model represents a sum of component state space models. Each of the N components describes a random process generating a distribution on an observed time series; call these series x1[t], x2[t], ..., xN[t]. The additive model represents the sum of these processes, y[t] = x1[t] + x2[t] + ... + xN[t] + eps[t], where eps[t] ~ N(0, observation_noise_scale), parameterized by its standard deviation, is an observation noise term.
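As a minimal base-R sketch (it does not call tfprobability), the additive observation process can be simulated directly. The two components used here, a Gaussian random walk standing in for a local level and a fixed period-4 seasonal pattern, as well as all scale values, are hypothetical choices for illustration.

```r
set.seed(42)
num_timesteps <- 20

# Hypothetical component 1: a Gaussian random walk (local-level-like process).
x1 <- cumsum(rnorm(num_timesteps, mean = 0, sd = 0.1))

# Hypothetical component 2: a deterministic seasonal pattern with period 4.
x2 <- rep(c(1, -1, 0.5, -0.5), length.out = num_timesteps)

# Observation noise eps[t] ~ N(0, observation_noise_scale), where the scale is
# read as a standard deviation.
observation_noise_scale <- 0.2
eps <- rnorm(num_timesteps, mean = 0, sd = observation_noise_scale)

# The additive model observes the sum of the component processes plus noise.
y <- x1 + x2 + eps
```

Removing the noise term recovers exactly the sum of the component series, which is what "additive" means here.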
Mathematical Details
The additive model concatenates the latent states of its component models. The generative process runs each component's dynamics in its own subspace of latent space, and then observes the sum of the observation models from the components.
Formally, the transition model is linear Gaussian:
p(z[t+1] | z[t]) ~ Normal(loc = transition_matrix.matmul(z[t]), cov = transition_cov)
where each z[t] is a latent state vector concatenating the component state vectors, z[t] = [z1[t], z2[t], ..., zN[t]], so it has size latent_size = sum([c.latent_size for c in components]).
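The concatenation of component states can be sketched in base R as follows. The component names, state sizes, and values are hypothetical; the point is only that latent_size is the sum of the component sizes and z[t] stacks the component states in order.

```r
# Hypothetical component latent states at a single timestep t.
# Sizes are illustrative: 1 state, 2 states, and 3 states.
component_states <- list(
  level    = c(0.5),
  trend    = c(1.2, 0.1),
  seasonal = c(0.3, -0.2, -0.1)
)

# latent_size = sum([c.latent_size for c in components])
latent_size <- sum(vapply(component_states, length, integer(1)))

# z[t] = [z1[t], z2[t], ..., zN[t]]: concatenate the component state vectors.
z_t <- unlist(component_states, use.names = FALSE)
```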
The transition matrix is the block-diagonal composition of the transition matrices of the component processes:
transition_matrix = [[ c1.transition_matrix, 0., ..., 0. ],
                     [ 0., c2.transition_matrix, ..., 0. ],
                     [ ... ... ... ],
                     [ 0., 0., ..., cN.transition_matrix ]]
and the noise covariance is similarly the block-diagonal composition of the component noise covariances:
transition_cov = [[ c1.transition_cov, 0., ..., 0. ],
                  [ 0., c2.transition_cov, ..., 0. ],
                  [ ... ... ... ],
                  [ 0., 0., ..., cN.transition_cov ]]
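The block-diagonal composition can be sketched in base R with a small helper. The helper block_diag is hypothetical (written here for illustration, not part of tfprobability), and the component matrices are illustrative: a 1-state local level and a 2-state (level, slope) local linear trend.

```r
# Hypothetical block-diagonal helper: places each square block on the diagonal
# of a larger matrix and leaves zeros everywhere else.
block_diag <- function(blocks) {
  sizes <- vapply(blocks, nrow, integer(1))
  n <- sum(sizes)
  out <- matrix(0, nrow = n, ncol = n)
  offset <- 0
  for (i in seq_along(blocks)) {
    idx <- offset + seq_len(sizes[i])
    out[idx, idx] <- blocks[[i]]
    offset <- offset + sizes[i]
  }
  out
}

# Illustrative component transition matrices and noise covariances.
level_transition <- matrix(1, nrow = 1, ncol = 1)
trend_transition <- matrix(c(1, 0, 1, 1), nrow = 2, ncol = 2)  # [[1, 1], [0, 1]]
level_cov <- matrix(0.1^2, nrow = 1, ncol = 1)
trend_cov <- diag(c(0.2^2, 0.05^2))

transition_matrix <- block_diag(list(level_transition, trend_transition))
transition_cov    <- block_diag(list(level_cov, trend_cov))
```

The zero off-diagonal blocks are what keep each component's dynamics confined to its own subspace of the latent state. Matrix::bdiag from the Matrix package performs the same composition and returns a sparse matrix.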
The observation model is also linear Gaussian,
p(y[t] | z[t]) ~ Normal(loc = observation_matrix.matmul(z[t]), stddev = observation_noise_scale)
This implementation assumes scalar observations, so observation_matrix has shape [1, latent_size].
The additive observation matrix concatenates the observation matrices of the components:
observation_matrix = concat([c1.observation_matrix, c2.observation_matrix, ..., cN.observation_matrix], axis=-1)
The effect is that each component observation matrix acts on the dimensions of latent state corresponding to that component, and the overall expected observation is the sum of the expected observations from each component.
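A base-R sketch of this property, under illustrative assumptions: a local level component observing its single state, and a local linear trend component observing its level but not its slope. All state values are hypothetical.

```r
# Illustrative component observation matrices, each of shape [1, c.latent_size].
level_obs <- matrix(1, nrow = 1, ncol = 1)
trend_obs <- matrix(c(1, 0), nrow = 1, ncol = 2)

# Hypothetical component latent states at one timestep.
z_level <- c(2.0)
z_trend <- c(3.0, 0.5)

# Column-wise concatenation gives observation_matrix of shape [1, latent_size],
# and z[t] stacks the component states in the same order.
observation_matrix <- cbind(level_obs, trend_obs)
z_t <- c(z_level, z_trend)

# Expected observation of the additive model ...
expected_obs <- drop(observation_matrix %*% z_t)
# ... equals the sum of the components' expected observations.
sum_of_components <- drop(level_obs %*% z_level) + drop(trend_obs %*% z_trend)
```

The ordering of columns in observation_matrix must match the ordering of component states in z[t]; otherwise a component's observation weights would act on another component's states.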
If observation_noise_scale is not explicitly specified, it is computed from the component processes by summing their noise variances and taking the square root:
observation_noise_scale = sqrt(sum([c.observation_noise_scale**2 for c in components]))
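The arithmetic can be checked with a one-line base-R sketch; the three component noise scales below are hypothetical values.

```r
# Hypothetical observation noise scales (standard deviations) of three components.
component_noise_scales <- c(0.3, 0.4, 0.0)

# Variances, not standard deviations, add for independent noise terms,
# so the combined scale is the square root of the summed variances.
observation_noise_scale <- sqrt(sum(component_noise_scales^2))
```

Here the result is 0.5; adding the standard deviations directly would overstate the combined noise as 0.7.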
See also
Other sts: sts_autoregressive_state_space_model(), sts_autoregressive(), sts_constrained_seasonal_state_space_model(), sts_dynamic_linear_regression_state_space_model(), sts_dynamic_linear_regression(), sts_linear_regression(), sts_local_level_state_space_model(), sts_local_level(), sts_local_linear_trend_state_space_model(), sts_local_linear_trend(), sts_seasonal_state_space_model(), sts_seasonal(), sts_semi_local_linear_trend_state_space_model(), sts_semi_local_linear_trend(), sts_smooth_seasonal_state_space_model(), sts_smooth_seasonal(), sts_sparse_linear_regression(), sts_sum()