R/distributions.R
tfd_joint_distribution_named_auto_batched.Rd
This class provides automatic vectorization and alternative semantics for tfd_joint_distribution_named(), which in many cases allows for simplifications in the model specification.
tfd_joint_distribution_named_auto_batched( model, batch_ndims = 0, use_vectorized_map = TRUE, validate_args = FALSE, name = NULL )
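model  A named list of distributions and of functions that return distributions. A function's argument names identify the components whose values it receives.
batch_ndims  Integer number of batch dimensions. The global batch shape of the model is the prefix of this length shared by the batch shapes of all component distributions. Default value: 0.
use_vectorized_map  Logical. Whether to use tf$vectorized_map to automatically vectorize evaluation of the model, so that the model only needs to describe a single sample. Some TensorFlow ops may be unsupported. Default value: TRUE.
validate_args  Logical, default FALSE. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs.
name  Name prefixed to Ops created by this class.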
a distribution instance.
Automatic vectorization
Auto-vectorized variants of JointDistribution allow the user to avoid explicitly annotating a model's vectorization semantics. When using manually vectorized joint distributions, each operation in the model must account for the possibility of batch dimensions in Distributions and their samples. By contrast, auto-vectorized models need only describe a single sample from the joint distribution; any batch evaluation is automated using tf$vectorized_map as required. In many cases this allows for significant simplifications. For example, the following manually vectorized tfd_joint_distribution_named() model:
model <- tfd_joint_distribution_named(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) {
      tfd_normal(loc = x[reticulate::py_ellipsis(), 1:2] + y[reticulate::py_ellipsis(), tf$newaxis], scale = 1)
    }
  )
)

can be written in auto-vectorized form as

model <- tfd_joint_distribution_named_auto_batched(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) tfd_normal(loc = x[1:2] + y, scale = 1)
  )
)

in which we were able to avoid explicitly accounting for batch dimensions when indexing and slicing computed quantities in the third component, z.
Note: auto-vectorization is still experimental and some TensorFlow ops may be unsupported. It can be disabled by setting use_vectorized_map = FALSE.
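The following is a minimal sketch of what auto-vectorization buys when several joint draws are requested at once. It assumes the tensorflow and tfprobability R packages and a working TensorFlow Probability installation. The model has the same three components as the auto-vectorized example above. The sample size of 4 is arbitrary, and the shapes in the comments are the expected results rather than printed output.

```r
library(tensorflow)
library(tfprobability)

# Same components as the auto-vectorized example: each function describes one draw.
jd <- tfd_joint_distribution_named_auto_batched(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) tfd_normal(loc = x[1:2] + y, scale = 1)
  )
)

# Request 4 joint draws. tf$vectorized_map supplies the leading sample
# dimension, so inside z the slice x[1:2] still acts on a single length-3 draw.
draws <- tfd_sample(jd, 4L)
draws$x$shape  # expected shape: [4, 3]
draws$y$shape  # expected shape: [4]
draws$z$shape  # expected shape: [4, 2]

# One joint log-density per draw.
lp <- tfd_log_prob(jd, draws)
lp$shape       # expected shape: [4]
```

The same slicing written as x[1:2] in a manually vectorized model would slice the sample axis once a sample shape is requested. Avoiding that is the point of the auto-batched variant.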
Alternative batch semantics
This class also provides alternative semantics for specifying a batch of independent (non-identical) joint distributions. Instead of simply summing the log_probs of component distributions (which may have different shapes), it first reduces the component log_probs to ensure that jd$log_prob(jd$sample()) always returns a scalar, unless batch_ndims is explicitly set to a non-zero value (in which case the result will have the corresponding tensor rank).
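The following minimal sketch contrasts the two cases just described. It assumes the tensorflow and tfprobability R packages are installed. The two-component model, a length-3 vector x and a y centred element-wise on x, is a hypothetical illustration rather than an example from the package.

```r
library(tensorflow)
library(tfprobability)

# Hypothetical model for illustration: both components have batch shape (3).
model <- list(
  x = tfd_normal(loc = 0, scale = tf$ones(3L)),
  y = function(x) tfd_normal(loc = x, scale = 1)
)

# Default batch_ndims = 0: component log_probs are fully reduced,
# so the joint log_prob of a single draw is a scalar.
jd <- tfd_joint_distribution_named_auto_batched(model)
lp <- tfd_log_prob(jd, tfd_sample(jd))
lp$shape$rank    # expected: 0

# batch_ndims = 1: the leading dimension (size 3) is treated as a batch of
# independent joint distributions, so log_prob keeps rank 1.
jd_b <- tfd_joint_distribution_named_auto_batched(model, batch_ndims = 1L)
lp_b <- tfd_log_prob(jd_b, tfd_sample(jd_b))
lp_b$shape$rank  # expected: 1
```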
The essential changes are:
- An event of JointDistributionNamedAutoBatched is the list of tensors produced by $sample(); thus, the event_shape is the list containing the shapes of the sampled tensors. These combine both the event and batch dimensions of the component distributions. By contrast, the event shape of a base JointDistribution does not include the batch dimensions of its component distributions.
- The batch_shape is a global property of the entire model, rather than a per-component property as in base JointDistributions. The global batch shape must be a prefix of the batch shapes of each component; the length of this prefix is specified by the optional argument batch_ndims. If batch_ndims is not specified, the model has batch shape ().
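A minimal sketch of the event_shape and batch_shape properties described in the list above, assuming the tensorflow and tfprobability R packages are installed. The model is hypothetical: x is a normal with component batch shape (3), and y is a scalar that depends on the sum of x. The comments give the expected shapes, not verbatim printed output.

```r
library(tensorflow)
library(tfprobability)

# Hypothetical model for illustration only.
jd <- tfd_joint_distribution_named_auto_batched(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = function(x) tfd_normal(loc = tf$reduce_sum(x), scale = 1)
  )
)

# event_shape is a named list with one shape per component; the component
# batch dimension of x (size 3) is folded into its event shape.
jd$event_shape$x  # expected shape: [3]
jd$event_shape$y  # expected shape: []

# With batch_ndims left at 0, the model as a whole has batch shape ().
jd$batch_shape    # expected shape: []
```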
For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()