R/distributions.R
tfd_joint_distribution_named_auto_batched.Rd
This class provides automatic vectorization and alternative semantics for
tfd_joint_distribution_named(), which in many cases allows for
simplifications in the model specification.
tfd_joint_distribution_named_auto_batched( model, batch_ndims = 0, use_vectorized_map = TRUE, validate_args = FALSE, name = NULL )
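Argument | Description
---|---
model | A named list of distributions, or of functions returning distributions whose argument names refer to other elements of the list (as in the examples below).
batch_ndims | Integer number of batch dimensions: the length of the prefix of each component's batch shape that forms the global batch shape of the model. Default value: 0, giving batch shape ().
use_vectorized_map | Logical. Whether to use tf$vectorized_map to automatically vectorize evaluation of the model; set to FALSE to disable auto-vectorization. Default value: TRUE.
validate_args | Logical. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs. Default value: FALSE.
name | Name prefixed to ops created by this class. Default value: NULL.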
Value
A distribution instance.
Automatic vectorization
Auto-vectorized variants of JointDistribution allow the user to avoid
explicitly annotating a model's vectorization semantics.
When using manually vectorized joint distributions, each operation in the
model must account for the possibility of batch dimensions in distributions
and their samples. By contrast, auto-vectorized models need only describe
a single sample from the joint distribution; any batch evaluation is
automated using tf$vectorized_map as required. In many cases this
allows for significant simplifications. For example, the following
manually vectorized tfd_joint_distribution_named() model:
model <- tfd_joint_distribution_named(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) {
      tfd_normal(
        loc = x[reticulate::py_ellipsis(), 1:2] + y[reticulate::py_ellipsis(), tf$newaxis],
        scale = 1
      )
    }
  )
)
can be written in auto-vectorized form as
model <- tfd_joint_distribution_named_auto_batched(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) tfd_normal(loc = x[1:2] + y, scale = 1)
  )
)
in which we avoid explicitly accounting for batch dimensions when
indexing and slicing computed quantities in the definition of z.
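To make the contrast concrete without depending on TensorFlow, here is a rough base-R analogy (not the tfprobability API). The example model's log density is written twice: once for a single joint sample, as in the auto-vectorized form, and once with explicit batch handling, as in the manually vectorized form. Mapping the single-sample function over the batch with vapply() stands in for what tf$vectorized_map does inside the auto-batched class, except that vapply() loops sequentially instead of vectorizing. The names log_prob_single() and log_prob_manual() are hypothetical.

```r
# Hypothetical helper: log density of ONE joint sample of the example model,
#   x ~ Normal(0, 1) (3 elements), y ~ Normal(0, 1), z ~ Normal(x[1:2] + y, 1).
# Indexing is written for a single sample only, as in the auto-batched form.
log_prob_single <- function(x, y, z) {
  sum(dnorm(x, 0, 1, log = TRUE)) +
    dnorm(y, 0, 1, log = TRUE) +
    sum(dnorm(z, mean = x[1:2] + y, sd = 1, log = TRUE))
}

# Hypothetical helper: manually vectorized version. x is n x 3, y has length n,
# z is n x 2, so every step must handle the batch (row) dimension explicitly.
log_prob_manual <- function(x, y, z) {
  n <- nrow(x)
  loc_z <- x[, 1:2, drop = FALSE] + y  # y is recycled down the rows: row i gets y[i]
  rowSums(matrix(dnorm(x, 0, 1, log = TRUE), nrow = n)) +
    dnorm(y, 0, 1, log = TRUE) +
    rowSums(matrix(dnorm(z, mean = loc_z, sd = 1, log = TRUE), nrow = n))
}

set.seed(1)
n <- 4
x <- matrix(rnorm(n * 3), nrow = n)
y <- rnorm(n)
z <- matrix(rnorm(n * 2), nrow = n)

# "Auto-batched" style: map the single-sample function over the batch.
lp_mapped <- vapply(seq_len(n), function(i) log_prob_single(x[i, ], y[i], z[i, ]), numeric(1))
lp_manual <- log_prob_manual(x, y, z)
all.equal(lp_mapped, lp_manual)  # TRUE
```

The single-sample function needs no drop = FALSE or recycling tricks, which is the same simplification the auto-batched model gains in its definition of z.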
Note: auto-vectorization is still experimental and some TensorFlow ops may
be unsupported. It can be disabled by setting use_vectorized_map = FALSE.
Alternative batch semantics
This class also provides alternative semantics for specifying a batch of
independent (non-identical) joint distributions.
Instead of simply summing the log_probs of component distributions
(which may have different shapes), it first reduces the component log_probs
to ensure that jd$log_prob(jd$sample()) always returns a scalar, unless
batch_ndims is explicitly set to a nonzero value (in which case the result
will have the corresponding tensor rank).
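The reduction can be sketched in base R, under the simplifying assumptions that each component's log_prob is already available as an array whose leading dimensions are the batch dimensions, and that there are no extra sample dimensions. The hypothetical helper reduce_joint_log_prob() sums each component over every dimension after the first batch_ndims and only then adds the components together; it mirrors the semantics described above and is not the tfprobability implementation.

```r
# Hypothetical helper: combine per-component log_prob arrays by first
# reducing each one, keeping only its leading batch_ndims dimensions.
reduce_joint_log_prob <- function(component_lps, batch_ndims = 0) {
  reduced <- lapply(component_lps, function(lp) {
    lp <- as.array(lp)
    if (batch_ndims == 0) {
      return(sum(lp))  # reduce every dimension: one number per component
    }
    # Keep the leading batch_ndims dimensions, sum over all trailing ones.
    apply(lp, seq_len(batch_ndims), sum)
  })
  Reduce(`+`, reduced)  # add the already-reduced components together
}

set.seed(2)
n <- 4
lps <- list(
  x = matrix(dnorm(rnorm(n * 3), log = TRUE), nrow = n),  # shape (4, 3)
  y = dnorm(rnorm(n), log = TRUE),                        # shape (4)
  z = matrix(dnorm(rnorm(n * 2), log = TRUE), nrow = n)   # shape (4, 2)
)
total <- reduce_joint_log_prob(lps)                       # batch_ndims = 0: a scalar
per_batch <- reduce_joint_log_prob(lps, batch_ndims = 1)  # one value per batch member
```

Adding the unreduced arrays directly, e.g. lps$x + lps$z, fails in R with a "non-conformable arrays" error; reducing each component first is what sidesteps such shape mismatches, and the per-batch values sum to the scalar total.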
The essential changes are:
- An event of JointDistributionNamedAutoBatched is the list of tensors
  produced by $sample(); thus, the event_shape is the list containing the
  shapes of sampled tensors. These combine both the event and batch
  dimensions of the component distributions. By contrast, the event shape
  of a base JointDistribution does not include the batch dimensions of
  component distributions.
- The batch_shape is a global property of the entire model, rather than a
  per-component property as in base JointDistributions. The global batch
  shape must be a prefix of the batch shapes of each component; the length
  of this prefix is specified by an optional argument batch_ndims. If
  batch_ndims is not specified, the model has batch shape ().
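The shape bookkeeping described in these two points can be sketched in base R with a hypothetical helper, auto_batched_shapes(), that takes each component's batch shape and event shape as integer vectors, with integer(0) standing for (). It checks that the global batch shape is a prefix of every component batch shape and returns it together with the per-component event shapes. For batch_ndims = 0 this follows the text above directly; for a nonzero batch_ndims, dropping the global prefix from each event shape is an assumption about how the class behaves.

```r
# Hypothetical helper: derive an auto-batched batch_shape and event_shape
# from per-component shapes (integer vectors; integer(0) means ()).
auto_batched_shapes <- function(batch_shapes, event_shapes, batch_ndims = 0) {
  prefix <- function(s) s[seq_len(batch_ndims)]
  global <- prefix(batch_shapes[[1]])
  for (s in batch_shapes) {
    # Every component needs at least batch_ndims batch dims sharing one prefix.
    if (length(s) < batch_ndims || !all(prefix(s) == global)) {
      stop("global batch shape must be a prefix of every component batch shape")
    }
  }
  # Event part of each component: its batch dims beyond the global prefix,
  # followed by its own event dims (assumed behaviour for batch_ndims > 0).
  event_shape <- Map(function(b, e) c(b[seq_along(b) > batch_ndims], e),
                     batch_shapes, event_shapes)
  list(batch_shape = global, event_shape = event_shape)
}

# The example model: x has batch shape (3), y is scalar, z has batch shape (2);
# all components have event shape ().
shapes <- auto_batched_shapes(
  batch_shapes = list(x = 3L, y = integer(0), z = 2L),
  event_shapes = list(x = integer(0), y = integer(0), z = integer(0))
)
shapes$batch_shape  # integer(0), i.e. ()
shapes$event_shape  # x = 3, y = integer(0), z = 2

# Two components sharing a leading batch dimension of size 4.
shapes2 <- auto_batched_shapes(
  batch_shapes = list(a = c(4L, 3L), b = 4L),
  event_shapes = list(a = integer(0), b = 2L),
  batch_ndims = 1
)
```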
For usage examples see, e.g., tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions:
tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(),
tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(),
tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(),
tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(),
tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(),
tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(),
tfd_gaussian_process_regression_model(), tfd_gaussian_process(),
tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(),
tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(),
tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(),
tfd_johnson_s_u(), tfd_joint_distribution_named(),
tfd_joint_distribution_sequential_auto_batched(),
tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(),
tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(),
tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(),
tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(),
tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(),
tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(),
tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(),
tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(),
tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(),
tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(),
tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(),
tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(),
tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(),
tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(),
tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(),
tfd_vector_diffeomixture(), tfd_vector_exponential_diag(),
tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(),
tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(),
tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(),
tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()