This class provides automatic vectorization and alternative semantics for tfd_joint_distribution_named(), which in many cases allows for simplifications in the model specification.

tfd_joint_distribution_named_auto_batched(
  model,
  batch_ndims = 0,
  use_vectorized_map = TRUE,
  validate_args = FALSE,
  name = NULL
)

Arguments

model

A named list of tfd$Distribution-like instances and/or distribution-making functions, where each function's arguments are the names of other list elements.

batch_ndims

Integer (or integer Tensor) giving the number of batch dimensions. The batch_shapes of all component distributions must be such that the prefixes of length batch_ndims broadcast to a consistent joint batch shape. Default value: 0.

use_vectorized_map

logical. Whether to use tf$vectorized_map to automatically vectorize evaluation of the model. This allows the model specification to focus on drawing a single sample, which is often simpler, but some ops may not be supported. Default value: TRUE.

validate_args

Logical. When TRUE, distribution parameters are checked for validity, at a possible cost in runtime performance. When FALSE, invalid inputs may silently produce incorrect outputs. Default value: FALSE.

name

name prefixed to Ops created by this class.

Value

a distribution instance.

Details

Automatic vectorization

Auto-vectorized variants of JointDistribution allow the user to avoid explicitly annotating a model's vectorization semantics. When using manually-vectorized joint distributions, each operation in the model must account for the possibility of batch dimensions in Distributions and their samples. By contrast, auto-vectorized models need only describe a single sample from the joint distribution; any batch evaluation is automated using tf$vectorized_map as required. In many cases this allows for significant simplifications. For example, the following manually-vectorized tfd_joint_distribution_named() model:

model <- tfd_joint_distribution_named(
    list(
      x = tfd_normal(loc = 0, scale = tf$ones(3L)),
      y = tfd_normal(loc = 0, scale = 1),
      z = function(y, x) {
        tfd_normal(loc = x[reticulate::py_ellipsis(), 1:2] + y[reticulate::py_ellipsis(), tf$newaxis], scale = 1)
      }
    )
)

can be written in auto-vectorized form as

model <- tfd_joint_distribution_named_auto_batched(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) {tfd_normal(loc = x[1:2] + y, scale = 1)}
  )
)

in which we were able to avoid explicitly accounting for batch dimensions when indexing and slicing computed quantities in the definition of z. Note: auto-vectorization is still experimental and some TensorFlow ops may be unsupported. It can be disabled by setting use_vectorized_map = FALSE.
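As a usage sketch (assuming the tensorflow and tfprobability packages are installed with a working TensorFlow backend; the object names model, draws and lp are illustrative only), the auto-vectorized model can be sampled in batch and scored even though its components are written for a single sample:

```r
library(tensorflow)
library(tfprobability)

# Auto-vectorized model: every component describes one joint sample.
model <- tfd_joint_distribution_named_auto_batched(
  list(
    x = tfd_normal(loc = 0, scale = tf$ones(3L)),
    y = tfd_normal(loc = 0, scale = 1),
    z = function(y, x) tfd_normal(loc = x[1:2] + y, scale = 1)
  )
)

# Draw 5 joint samples; the result is a named list of tensors.
draws <- tfd_sample(model, 5L)
draws$x$shape  # sample dimension of 5 followed by the length-3 event
draws$z$shape  # sample dimension of 5 followed by the length-2 event

# One joint log-density per draw.
lp <- tfd_log_prob(model, draws)
lp$shape
```

The slicing x[1:2] inside z never sees the leading sample dimension; the batching over the 5 draws is handled by the distribution itself.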

Alternative batch semantics

This class also provides alternative semantics for specifying a batch of independent (non-identical) joint distributions. Instead of simply summing the log_probs of component distributions (which may have different shapes), it first reduces the component log_probs to ensure that jd$log_prob(jd$sample()) always returns a scalar, unless batch_ndims is explicitly set to a nonzero value (in which case the result will have the corresponding tensor rank).

The essential changes are:

  • An event of JointDistributionNamedAutoBatched is the list of tensors produced by $sample(); thus, the event_shape is the list containing the shapes of sampled tensors. These combine both the event and batch dimensions of the component distributions. By contrast, the event shape of a base JointDistribution does not include batch dimensions of component distributions.

  • The batch_shape is a global property of the entire model, rather than a per-component property as in base JointDistributions. The global batch shape must be a prefix of the batch shapes of each component; the length of this prefix is specified by an optional argument batch_ndims. If batch_ndims is not specified, the model has batch shape ().
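The effect of batch_ndims described above can be illustrated with a minimal sketch. It assumes the tensorflow and tfprobability packages with a working TensorFlow backend; the helper make_jd and the variable names are hypothetical and exist only for this example:

```r
library(tensorflow)
library(tfprobability)

# Hypothetical helper: the same two-component model, varying only batch_ndims.
make_jd <- function(batch_ndims) {
  tfd_joint_distribution_named_auto_batched(
    list(
      x = tfd_normal(loc = 0, scale = tf$ones(3L)),
      y = function(x) tfd_normal(loc = x, scale = 1)
    ),
    batch_ndims = batch_ndims
  )
}

# Default semantics: the length-3 dimension is part of the event,
# so the log-density of one joint sample is a scalar.
jd_event <- make_jd(0L)
lp_scalar <- tfd_log_prob(jd_event, tfd_sample(jd_event))
lp_scalar$shape

# batch_ndims = 1: the leading dimension of every component is treated as a
# global batch dimension, so the log-density is a rank-1 tensor.
jd_batch <- make_jd(1L)
lp_batch <- tfd_log_prob(jd_batch, tfd_sample(jd_batch))
lp_batch$shape
jd_batch$batch_shape
```

Both components here have batch shape (3), so a prefix of length 1 is consistent across the model, as the batch_ndims argument requires.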

See also

For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().

Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), tfd_vector_laplace_linear_operator(), 
tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()