The affine autoregressive flow (Papamakarios et al., 2016) provides a relatively simple framework for user-specified (deep) architectures to learn a distribution over continuous events.
    tfb_masked_autoregressive_flow(
      shift_and_log_scale_fn,
      is_constant_jacobian = FALSE,
      unroll_loop = FALSE,
      event_ndims = 1L,
      validate_args = FALSE,
      name = NULL
    )
Argument | Description |
---|---|
shift_and_log_scale_fn | Function which computes shift and log_scale from both the forward domain (x) and the inverse domain (y). Calculation must respect the "autoregressive property". Suggested default: tfb_masked_autoregressive_default_template(hidden_layers=...). Typically the function contains |
is_constant_jacobian | Logical, default: FALSE. When TRUE the implementation assumes log_scale does not depend on the forward domain (x) or inverse domain (y) values. (No validation is made; is_constant_jacobian=FALSE is always safe but possibly computationally inefficient.) |
unroll_loop | Logical indicating whether the |
event_ndims | Integer, the intrinsic dimensionality of this bijector. 1 corresponds to a simple vector autoregressive bijector as implemented by the |
validate_args | Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE, and the inputs are invalid, correct behavior is not guaranteed. |
name | Name prefixed to Ops created by this class. |
Returns a bijector instance.
"Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density by an invertible transformation with tractable Jacobian." (Papamakarios et al., 2016)
In other words, the "autoregressive property" is equivalent to the decomposition

    p(x) = prod{ p(x[perm[i]] | x[perm[0:i]]) : i=0, ..., d }

where perm is some permutation of {0, ..., d}. In the simple case where the permutation is identity this reduces to:

    p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }

The provided shift_and_log_scale_fn, tfb_masked_autoregressive_default_template, achieves this property by zeroing out weights in its masked_dense layers.
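To make the weight-masking idea concrete, the following is a minimal NumPy sketch of MADE-style masks for a single hidden layer. It is not the code used by tfb_masked_autoregressive_default_template or masked_dense; the helper name `made_masks` and the degree assignment are assumptions chosen for illustration. The check at the end confirms that, with these masks, output i can only be reached from inputs with index less than i.

```python
import numpy as np

def made_masks(d, hidden):
    """Illustrative MADE-style masks for one hidden layer (hypothetical helper)."""
    in_deg = np.arange(1, d + 1)                 # input j gets degree j + 1
    hid_deg = (np.arange(hidden) % (d - 1)) + 1  # hidden degrees cycle over 1..d-1
    out_deg = np.arange(1, d + 1)                # output i gets degree i + 1
    # A hidden unit may see inputs whose degree is <= its own degree.
    mask_in = (hid_deg[:, None] >= in_deg[None, :]).astype(float)   # (hidden, d)
    # An output may see hidden units whose degree is strictly smaller.
    mask_out = (out_deg[:, None] > hid_deg[None, :]).astype(float)  # (d, hidden)
    return mask_in, mask_out

d, hidden = 4, 8
rng = np.random.default_rng(0)
mask_in, mask_out = made_masks(d, hidden)

# Zero out the disallowed weights, as the masked layers do.
w_in = rng.normal(size=(hidden, d)) * mask_in
w_out = rng.normal(size=(d, hidden)) * mask_out

# connectivity[i, j] is True when some path leads from input j to output i.
connectivity = (np.abs(w_out) @ np.abs(w_in)) > 0

# Autoregressive property: nothing on or above the diagonal.
violates = bool(np.triu(connectivity).any())
print(connectivity.astype(int))
print("violates autoregressive property:", violates)
```

The strict inequality on the output mask is what removes the diagonal: output i never depends on input i itself, only on earlier inputs.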
In TensorFlow Probability, "normalizing flows" are implemented as
tfp.bijectors.Bijectors. The forward "autoregression" is implemented
using a tf.while_loop and a deep neural network (DNN) with masked weights
such that the autoregressive property is automatically met in the inverse.
A TransformedDistribution using MaskedAutoregressiveFlow(...) uses the
(expensive) forward-mode calculation to draw samples and the (cheap)
reverse-mode calculation to compute log-probabilities. Conversely, a
TransformedDistribution using Invert(MaskedAutoregressiveFlow(...)) uses
the (expensive) forward-mode calculation to compute log-probabilities and the
(cheap) reverse-mode calculation to compute samples.
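The cost asymmetry can be sketched in plain NumPy by counting how often the conditioner is evaluated. This is an illustration under stated assumptions, not the TensorFlow Probability implementation: the toy `shift_and_log_scale_fn`, the `sample` and `log_prob` helpers, and the standard normal base density are all hypothetical stand-ins. The log-density uses the change-of-variables formula, where the inverse log-determinant of an elementwise affine autoregressive map is minus the sum of log_scale.

```python
import numpy as np

calls = {"n": 0}  # counts conditioner evaluations

def shift_and_log_scale_fn(y):
    """Toy conditioner (hypothetical): output i depends only on y[..., :i]."""
    calls["n"] += 1
    csum = np.cumsum(y, axis=-1)
    prev = np.concatenate([np.zeros_like(y[..., :1]), csum[..., :-1]], axis=-1)
    return 0.5 * prev, np.tanh(0.3 * prev)

def sample(rng, n, d):
    """Forward pass: needs d conditioner calls."""
    x = rng.standard_normal((n, d))
    y = np.zeros_like(x)
    for _ in range(d):
        shift, log_scale = shift_and_log_scale_fn(y)
        y = x * np.exp(log_scale) + shift
    return y

def log_prob(y):
    """Inverse pass: a single conditioner call."""
    shift, log_scale = shift_and_log_scale_fn(y)
    x = (y - shift) * np.exp(-log_scale)
    base = -0.5 * np.sum(x**2 + np.log(2.0 * np.pi), axis=-1)
    return base - np.sum(log_scale, axis=-1)  # add inverse log-det-Jacobian

rng = np.random.default_rng(0)
d = 5
y = sample(rng, 3, d)
sample_calls = calls["n"]
lp = log_prob(y)
density_calls = calls["n"] - sample_calls
print("conditioner calls to sample:", sample_calls)
print("conditioner calls for log_prob:", density_calls)
```

Wrapping the bijector in an inversion swaps the two roles, so sampling becomes the single-call direction and density evaluation becomes the looped one.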
Given a shift_and_log_scale_fn, the forward and inverse transformations are (a sequence of) affine transformations. A "valid" shift_and_log_scale_fn must compute each shift (aka loc or "mu" in Germain et al. (2015)) and log(scale) (aka "alpha" in Germain et al. (2015)) such that each is broadcastable with the arguments to forward and inverse, i.e., such that the calculations in forward and inverse below are possible.
For convenience, tfb_masked_autoregressive_default_template is offered as a possible shift_and_log_scale_fn function. It implements the MADE architecture (Germain et al., 2015). MADE is a feed-forward network that computes a shift and log(scale) using masked_dense layers in a deep neural network. Weights are masked to ensure the autoregressive property. This architecture may be suboptimal for your task. To build alternative networks, either change the arguments to tfb_masked_autoregressive_default_template, use the masked_dense function to roll your own, or use some other architecture, e.g., using tf.layers.

Warning: no attempt is made to validate that the shift_and_log_scale_fn enforces the "autoregressive property".
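Because the property is not validated, a custom conditioner can be spot-checked numerically before use. The sketch below is one possible check, written in NumPy for a single event vector; the helper `is_autoregressive` and the two toy conditioners are hypothetical and not part of the package. It perturbs one input at a time and verifies that outputs at the same or an earlier index do not move. Passing the check on a few random points is evidence, not proof.

```python
import numpy as np

def is_autoregressive(fn, d, trials=5, seed=0):
    """Spot-check: changing input j must not change outputs 0..j."""
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        y = rng.standard_normal(d)
        base_shift, base_log_scale = fn(y)
        for j in range(d):
            y_pert = y.copy()
            y_pert[j] += 1.0
            shift, log_scale = fn(y_pert)
            same_shift = np.allclose(shift[: j + 1], base_shift[: j + 1])
            same_scale = np.allclose(log_scale[: j + 1], base_log_scale[: j + 1])
            if not (same_shift and same_scale):
                return False
    return True

def good_fn(y):
    """Output i uses only the sum of y[:i] (autoregressive)."""
    prev = np.concatenate([[0.0], np.cumsum(y)[:-1]])
    return 0.5 * prev, np.tanh(0.3 * prev)

def bad_fn(y):
    """Every output uses the sum of all inputs (not autoregressive)."""
    total = np.sum(y) * np.ones_like(y)
    return 0.5 * total, np.tanh(0.3 * total)

ok_good = is_autoregressive(good_fn, d=4)
ok_bad = is_autoregressive(bad_fn, d=4)
print("good_fn passes:", ok_good)
print("bad_fn passes:", ok_bad)
```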
Assuming shift_and_log_scale_fn has valid shape and autoregressive semantics, the forward transformation is
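    def forward(x):
      y = tf.zeros_like(x)
      # Number of elements in the event, i.e. the trailing event_ndims dimensions.
      event_size = x.shape[-event_ndims:].num_elements()
      # Each pass makes one more element of y correct.
      for _ in range(event_size):
        shift, log_scale = shift_and_log_scale_fn(y)
        y = x * tf.exp(log_scale) + shift
      return y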
and the inverse transformation is
    def inverse(y):
      shift, log_scale = shift_and_log_scale_fn(y)
      return (y - shift) / tf.exp(log_scale)
Notice that the inverse does not need a for-loop. This is because in the forward pass each calculation of shift and log_scale is based on the y calculated so far (not x). In the inverse, y is fully known, so a single call to shift_and_log_scale_fn yields the same shift and log_scale that forward uses after event_size passes, i.e., those computed from the "last" y. (Roughly speaking, this also proves the transform is bijective.)
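The round trip can be checked with a small NumPy transcription of the two functions. This is a sketch under assumptions: the conditioner below is a hypothetical hand-written function rather than a masked network, and the event is taken to be the last axis (the event_ndims = 1 case).

```python
import numpy as np

def shift_and_log_scale_fn(y):
    """Toy conditioner (hypothetical): output i depends only on y[..., :i]."""
    csum = np.cumsum(y, axis=-1)
    # prev[..., i] is the sum of y[..., :i], so element 0 sees nothing.
    prev = np.concatenate([np.zeros_like(y[..., :1]), csum[..., :-1]], axis=-1)
    return 0.5 * prev, np.tanh(0.3 * prev)

def forward(x):
    y = np.zeros_like(x)
    event_size = x.shape[-1]
    # After pass k, the first k elements of y have their final values.
    for _ in range(event_size):
        shift, log_scale = shift_and_log_scale_fn(y)
        y = x * np.exp(log_scale) + shift
    return y

def inverse(y):
    # y is fully known, so one conditioner call is enough.
    shift, log_scale = shift_and_log_scale_fn(y)
    return (y - shift) / np.exp(log_scale)

x = np.array([[0.3, -1.2, 0.7, 2.0]])
y = forward(x)
x_roundtrip = inverse(y)
print("max round-trip error:", np.max(np.abs(x_roundtrip - x)))
```

The first element passes through unchanged because its shift and log_scale are computed from no inputs, which in this toy conditioner gives zero for both.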
For usage examples see tfb_forward(), tfb_inverse(), tfb_inverse_log_det_jacobian().
Other bijectors: tfb_absolute_value(), tfb_affine_linear_operator(), tfb_affine_scalar(), tfb_affine(), tfb_ascending(), tfb_batch_normalization(), tfb_blockwise(), tfb_chain(), tfb_cholesky_outer_product(), tfb_cholesky_to_inv_cholesky(), tfb_correlation_cholesky(), tfb_cumsum(), tfb_discrete_cosine_transform(), tfb_expm1(), tfb_exp(), tfb_ffjord(), tfb_fill_scale_tri_l(), tfb_fill_triangular(), tfb_glow(), tfb_gompertz_cdf(), tfb_gumbel_cdf(), tfb_gumbel(), tfb_identity(), tfb_inline(), tfb_invert(), tfb_iterated_sigmoid_centered(), tfb_kumaraswamy_cdf(), tfb_kumaraswamy(), tfb_lambert_w_tail(), tfb_masked_autoregressive_default_template(), tfb_masked_dense(), tfb_matrix_inverse_tri_l(), tfb_matvec_lu(), tfb_normal_cdf(), tfb_ordered(), tfb_pad(), tfb_permute(), tfb_power_transform(), tfb_rational_quadratic_spline(), tfb_rayleigh_cdf(), tfb_real_nvp_default_template(), tfb_real_nvp(), tfb_reciprocal(), tfb_reshape(), tfb_scale_matvec_diag(), tfb_scale_matvec_linear_operator(), tfb_scale_matvec_lu(), tfb_scale_matvec_tri_l(), tfb_scale_tri_l(), tfb_scale(), tfb_shifted_gompertz_cdf(), tfb_shift(), tfb_sigmoid(), tfb_sinh_arcsinh(), tfb_sinh(), tfb_softmax_centered(), tfb_softplus(), tfb_softsign(), tfb_split(), tfb_square(), tfb_tanh(), tfb_transform_diagonal(), tfb_transpose(), tfb_weibull_cdf(), tfb_weibull()