tfd_joint_distribution_sequential_auto_batched

This class provides automatic vectorization and alternative semantics for
tfd_joint_distribution_sequential(), which in many cases allows for
simplifications in the model specification.
Usage

tfd_joint_distribution_sequential_auto_batched(
  model,
  batch_ndims = 0,
  use_vectorized_map = TRUE,
  validate_args = FALSE,
  name = NULL
)
Arguments

| Argument | Description |
|---|---|
| model | A generator that yields a sequence of tfd$Distribution-like instances. |
| batch_ndims | Integer number of batch dimensions; the global batch shape of the model is the length-batch_ndims prefix of the component batch shapes. Default value: 0. |
| use_vectorized_map | Logical. Whether to use tf$vectorized_map to automatically vectorize evaluation of the model, so that the model specification need only describe a single sample. Default value: TRUE. |
| validate_args | Logical, default FALSE. When TRUE distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE invalid inputs may silently render incorrect outputs. Default value: FALSE. |
| name | Name prefixed to Ops created by this class. |
Value

a distribution instance.
Automatic vectorization
Auto-vectorized variants of JointDistribution allow the user to avoid
explicitly annotating a model's vectorization semantics.
When using manually-vectorized joint distributions, each operation in the
model must account for the possibility of batch dimensions in Distributions
and their samples. By contrast, auto-vectorized models need only describe
a single sample from the joint distribution; any batch evaluation is
automated using tf$vectorized_map as required. In many cases this
allows for significant simplifications. For example, the following
manually-vectorized tfd_joint_distribution_sequential() model:
model <- tfd_joint_distribution_sequential(
  list(
    tfd_normal(loc = 0, scale = tf$ones(3L)),
    tfd_normal(loc = 0, scale = 1),
    function(y, x) {
      tfd_normal(
        loc = x[reticulate::py_ellipsis(), 1:2] + y[reticulate::py_ellipsis(), tf$newaxis],
        scale = 1
      )
    }
  )
)
can be written in auto-vectorized form as
model <- tfd_joint_distribution_sequential_auto_batched(
  list(
    tfd_normal(loc = 0, scale = tf$ones(3L)),
    tfd_normal(loc = 0, scale = 1),
    function(y, x) {
      tfd_normal(loc = x[1:2] + y, scale = 1)
    }
  )
)
in which we were able to avoid explicitly accounting for batch dimensions
when indexing and slicing computed quantities in the third line.
Note: auto-vectorization is still experimental and some TensorFlow ops may
be unsupported. It can be disabled by setting use_vectorized_map=FALSE.
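As a minimal sketch (assuming the tensorflow and tfprobability packages are
attached and a working TensorFlow Probability installation), the auto-batched
model above can be constructed and sampled directly; if an op in the model is
not supported by the vectorization machinery, the same call can be made with
use_vectorized_map = FALSE:

library(tensorflow)
library(tfprobability)

model <- tfd_joint_distribution_sequential_auto_batched(
  list(
    tfd_normal(loc = 0, scale = tf$ones(3L)),
    tfd_normal(loc = 0, scale = 1),
    function(y, x) {
      tfd_normal(loc = x[1:2] + y, scale = 1)
    }
  )
)

# A single draw is a list of three tensors with shapes [3], [], and [2].
xs <- model %>% tfd_sample()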
Alternative batch semantics
This class also provides alternative semantics for specifying a batch of
independent (non-identical) joint distributions.
Instead of simply summing the log_probs of component distributions
(which may have different shapes), it first reduces the component log_probs
to ensure that jd$log_prob(jd$sample()) always returns a scalar, unless
batch_ndims is explicitly set to a nonzero value (in which case the result
will have the corresponding tensor rank).
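Continuing the sketch above (the indicated shapes are what these semantics
imply, not output copied from the package documentation), log_prob reduces
over all component event and batch dimensions, so only explicitly requested
sample dimensions remain:

# The log-probability of a single draw is a scalar.
model %>% tfd_log_prob(model %>% tfd_sample())

# With an explicit sample shape, only that dimension survives: shape [5].
x5 <- model %>% tfd_sample(5)
model %>% tfd_log_prob(x5)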
The essential changes are:
- An event of JointDistributionSequentialAutoBatched is the list of
  tensors produced by $sample(); thus, the event_shape is the
  list containing the shapes of sampled tensors. These combine both
  the event and batch dimensions of the component distributions. By contrast,
  the event shape of a base JointDistribution does not include batch
  dimensions of component distributions.
- The batch_shape is a global property of the entire model, rather
  than a per-component property as in base JointDistributions.
  The global batch shape must be a prefix of the batch shapes of
  each component; the length of this prefix is specified by an optional
  argument batch_ndims. If batch_ndims is not specified, the model has
  batch shape (). (See the sketch following this list.)
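A rough sketch of the batch semantics (the component shapes and the
batch_ndims = 1 setting here are illustrative assumptions, not an example
from the package documentation): giving every component a common leading
dimension and declaring it via batch_ndims yields a model with that global
batch shape:

# Every component shares a leading dimension of size 4, declared as the
# global batch dimension via batch_ndims = 1.
batched <- tfd_joint_distribution_sequential_auto_batched(
  list(
    tfd_normal(loc = tf$zeros(c(4L, 3L)), scale = 1),
    tfd_normal(loc = tf$zeros(4L), scale = 1),
    function(y, x) {
      tfd_normal(loc = x[1:2] + y, scale = 1)
    }
  ),
  batch_ndims = 1
)

batched$batch_shape                                  # expected: (4)
batched$event_shape                                  # expected: shapes (3), (), (2)
batched %>% tfd_log_prob(batched %>% tfd_sample())   # expected shape: [4]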
See also

For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions:
tfd_autoregressive(),
tfd_batch_reshape(),
tfd_bates(),
tfd_bernoulli(),
tfd_beta_binomial(),
tfd_beta(),
tfd_binomial(),
tfd_categorical(),
tfd_cauchy(),
tfd_chi2(),
tfd_chi(),
tfd_cholesky_lkj(),
tfd_continuous_bernoulli(),
tfd_deterministic(),
tfd_dirichlet_multinomial(),
tfd_dirichlet(),
tfd_empirical(),
tfd_exp_gamma(),
tfd_exp_inverse_gamma(),
tfd_exponential(),
tfd_gamma_gamma(),
tfd_gamma(),
tfd_gaussian_process_regression_model(),
tfd_gaussian_process(),
tfd_generalized_normal(),
tfd_geometric(),
tfd_gumbel(),
tfd_half_cauchy(),
tfd_half_normal(),
tfd_hidden_markov_model(),
tfd_horseshoe(),
tfd_independent(),
tfd_inverse_gamma(),
tfd_inverse_gaussian(),
tfd_johnson_s_u(),
tfd_joint_distribution_named_auto_batched(),
tfd_joint_distribution_named(),
tfd_joint_distribution_sequential(),
tfd_kumaraswamy(),
tfd_laplace(),
tfd_linear_gaussian_state_space_model(),
tfd_lkj(),
tfd_log_logistic(),
tfd_log_normal(),
tfd_logistic(),
tfd_mixture_same_family(),
tfd_mixture(),
tfd_multinomial(),
tfd_multivariate_normal_diag_plus_low_rank(),
tfd_multivariate_normal_diag(),
tfd_multivariate_normal_full_covariance(),
tfd_multivariate_normal_linear_operator(),
tfd_multivariate_normal_tri_l(),
tfd_multivariate_student_t_linear_operator(),
tfd_negative_binomial(),
tfd_normal(),
tfd_one_hot_categorical(),
tfd_pareto(),
tfd_pixel_cnn(),
tfd_poisson_log_normal_quadrature_compound(),
tfd_poisson(),
tfd_power_spherical(),
tfd_probit_bernoulli(),
tfd_quantized(),
tfd_relaxed_bernoulli(),
tfd_relaxed_one_hot_categorical(),
tfd_sample_distribution(),
tfd_sinh_arcsinh(),
tfd_skellam(),
tfd_spherical_uniform(),
tfd_student_t_process(),
tfd_student_t(),
tfd_transformed_distribution(),
tfd_triangular(),
tfd_truncated_cauchy(),
tfd_truncated_normal(),
tfd_uniform(),
tfd_variational_gaussian_process(),
tfd_vector_diffeomixture(),
tfd_vector_exponential_diag(),
tfd_vector_exponential_linear_operator(),
tfd_vector_laplace_diag(),
tfd_vector_laplace_linear_operator(),
tfd_vector_sinh_arcsinh_diag(),
tfd_von_mises_fisher(),
tfd_von_mises(),
tfd_weibull(),
tfd_wishart_linear_operator(),
tfd_wishart_tri_l(),
tfd_wishart(),
tfd_zipf()