A vector diffeomixture (VDM) is a distribution parameterized by a convex combination of K component loc vectors, loc[k], k = 0, ..., K-1, and K scale matrices, scale[k], k = 0, ..., K-1. It approximates the compound distribution p(x) = int p(x | z) p(z) dz, where z lies in the K-simplex and p(x | z) := p(x | loc = sum_k z[k] loc[k], scale = sum_k z[k] scale[k]).
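The convex combination in p(x | z) can be sketched in plain base R. The helper mix_affine() below is hypothetical (it is not part of tfprobability), and ordinary matrices stand in for the LinearOperators the real class expects; it only shows how simplex weights z interpolate the K locs and K scales.

```r
# Hypothetical helper (not part of tfprobability): given simplex weights z,
# interpolate K component locs and K scale matrices as in p(x | z).
mix_affine <- function(z, locs, scales) {
  stopifnot(length(z) == length(locs), length(z) == length(scales),
            all(z >= 0), isTRUE(all.equal(sum(z), 1)))
  list(
    loc   = Reduce(`+`, Map(`*`, z, locs)),    # sum_k z[k] loc[k]
    scale = Reduce(`+`, Map(`*`, z, scales))   # sum_k z[k] scale[k]
  )
}

# K = 2 components acting on d = 2 dimensional vectors; plain matrices
# stand in for LinearOperators.
locs   <- list(c(0, 0), c(2, -2))
scales <- list(diag(2), matrix(c(2, 0.5, 0.5, 1), nrow = 2))
p <- mix_affine(c(0.25, 0.75), locs, scales)
p$loc    # c(1.5, -1.5)
p$scale  # rows (1.75, 0.375) and (0.375, 1)
```

Because a convex combination of positive-definite matrices is positive-definite, the interpolated scale inherits the positive-definiteness required of each scale[k].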

tfd_vector_diffeomixture(
  mix_loc,
  temperature,
  distribution,
  loc = NULL,
  scale = NULL,
  quadrature_size = 8,
  quadrature_fn = tfp$distributions$quadrature_scheme_softmaxnormal_quantiles,
  validate_args = FALSE,
  allow_nan_stats = TRUE,
  name = "VectorDiffeomixture"
)

Arguments

mix_loc

float-like Tensor with shape [b1, ..., bB, K-1]. In terms of samples, a larger mix_loc[..., k] means Z is more likely to put more weight on its k-th component.

temperature

float-like Tensor, broadcastable with mix_loc. In terms of samples, a smaller temperature means one component is more likely to dominate; that is, a smaller temperature makes the VDM look more like a standard mixture of K components.

distribution

tfp$distributions$Distribution-like instance. Distribution from which d iid samples are used as input to the selected affine transformation. Must be a scalar-batch, scalar-event distribution. Typically distribution$reparameterization_type = FULLY_REPARAMETERIZED or it is a function of non-trainable parameters. WARNING: If you backprop through a VectorDiffeomixture sample and the distribution is not FULLY_REPARAMETERIZED yet is a function of trainable variables, then the gradient will be incorrect!

loc

Length-K list of float-type Tensors. The k-th element represents the shift used for the k-th affine transformation. If the k-th item is NULL, loc is implicitly 0. When specified, must have shape [B1, ..., Bb, d] where b >= 0 and d is the event size.

scale

Length-K list of LinearOperators. Each should be positive-definite and operate on a d-dimensional vector space. The k-th element represents the scale used for the k-th affine transformation. LinearOperators must have shape [B1, ..., Bb, d, d] with b >= 0, i.e., each characterizes b batch dimensions of d x d matrices.

quadrature_size

Integer scalar giving the number of quadrature points. A larger quadrature_size means q_N(x) better approximates p(x).

quadrature_fn

Function taking normal_loc, normal_scale, quadrature_size, and validate_args and returning tuple(grid, probs), representing the SoftmaxNormal grid and the corresponding normalized weights. Default value: quadrature_scheme_softmaxnormal_quantiles.

validate_args

Logical, default FALSE. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs.

allow_nan_stats

Logical, default TRUE. When TRUE, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When FALSE, an exception is raised if one or more of the statistic's batch members are undefined.

name

name prefixed to Ops created by this class.

Value

a distribution instance.

Details

The integral int p(x | z) p(z) dz is approximated with a quadrature scheme adapted to the mixture density p(z). The N quadrature points z_{N, n} and weights w_{N, n} (which are non-negative and sum to 1) are chosen such that q_N(x) := sum_{n=1}^N w_{N, n} p(x | z_{N, n}) --> p(x) as N --> infinity.
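A minimal base-R sketch of evaluating q_N(x), assuming K = 2, d = 1, and a Normal conditional p(x | z) whose mean and standard deviation are the z-weighted combinations of two component values. The locs, scales, and grid below are illustrative choices, not the default TFP quadrature scheme; any valid set of simplex points with non-negative weights summing to 1 makes q_N a proper density.

```r
# Illustrative setup (assumed, not from the source): K = 2, d = 1, and
# p(x | z) = Normal(sum_k z[k] loc[k], sum_k z[k] scale[k]).
locs   <- c(-2, 3)
scales <- c(0.5, 1.5)

# q_N(x) = sum_n w[n] * p(x | z_n); vectorized over x.
q_N <- function(x, grid, w) {
  dens <- numeric(length(x))
  for (n in seq_len(nrow(grid))) {
    z <- grid[n, ]
    dens <- dens + w[n] * dnorm(x, mean = sum(z * locs), sd = sum(z * scales))
  }
  dens
}

# A hypothetical grid of N simplex points with equal weights. This is NOT the
# default TFP quadrature scheme, just one valid (grid, weights) pair.
N    <- 4
z1   <- (seq_len(N) - 0.5) / N
grid <- cbind(z1, 1 - z1)
w    <- rep(1 / N, N)

total_mass <- integrate(q_N, -Inf, Inf, grid = grid, w = w)$value
total_mass  # approximately 1
```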

Since q_N(x) is in fact a mixture (of N points), we may sample from q_N exactly. It is important to note that the VDM is defined as q_N above, and not p(x). Therefore, sampling and pdf may be implemented as exact (up to floating point error) methods.
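Exact sampling from a finite mixture is a two-step procedure: pick a quadrature point n with probability w_{N, n}, then draw from p(x | z_{N, n}). The sketch below reuses an assumed K = 2, d = 1 Normal setup with a hypothetical equal-weight grid (not TFP's default scheme) and checks the sample mean against the exact mixture mean.

```r
set.seed(1)
# Illustrative setup (assumed): K = 2, d = 1, Normal conditional.
locs   <- c(-2, 3)
scales <- c(0.5, 1.5)
N    <- 4
z1   <- (seq_len(N) - 0.5) / N
grid <- cbind(z1, 1 - z1)   # hypothetical grid of simplex points
w    <- rep(1 / N, N)       # weights summing to 1

sample_q_N <- function(n_samples) {
  # Step 1: choose a quadrature point n with probability w[n].
  idx <- sample.int(N, n_samples, replace = TRUE, prob = w)
  # Step 2: draw from p(x | z_n) = Normal(loc(z_n), scale(z_n)).
  mu    <- as.vector(grid[idx, , drop = FALSE] %*% locs)
  sigma <- as.vector(grid[idx, , drop = FALSE] %*% scales)
  rnorm(n_samples, mean = mu, sd = sigma)
}

x <- sample_q_N(1e5)
exact_mean <- sum(w * (grid %*% locs))  # mean of q_N: sum_n w[n] loc(z_n)
c(sample_mean = mean(x), exact_mean = exact_mean)
```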

A common choice for the conditional p(x | z) is a multivariate Normal. The implemented marginal p(z) is the SoftmaxNormal, a (K-1)-dimensional Normal transformed by a SoftmaxCentered bijector, which makes it a density on the K-simplex. That is, Z = SoftmaxCentered(X), where X ~ Normal(mix_loc / temperature, 1 / temperature).
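A base-R sketch of drawing Z from the SoftmaxNormal, under two stated assumptions: softmax_centered() (a hypothetical helper) appends a 0 to x before applying the softmax, and the K-1 coordinates of X are independent with standard deviation 1 / temperature. It also illustrates the argument descriptions: the component with the largest mix_loc receives the most weight on average, and a small temperature pushes samples toward a single dominant component.

```r
# Assumed definition: SoftmaxCentered appends a 0 to x in R^(K-1) and applies
# softmax, giving a point on the K-simplex.
softmax_centered <- function(x) {
  v <- c(x, 0)
  e <- exp(v - max(v))  # subtract the max for numerical stability
  e / sum(e)
}

# Draw n samples of Z = SoftmaxCentered(X), X ~ Normal(mix_loc / temperature,
# 1 / temperature) with independent coordinates; returns an n x K matrix.
r_softmax_normal <- function(n, mix_loc, temperature) {
  t(replicate(n, softmax_centered(
    rnorm(length(mix_loc), mean = mix_loc / temperature, sd = 1 / temperature))))
}

set.seed(42)
mix_loc <- c(0.5, -0.5)  # K - 1 = 2, so K = 3 components
z_warm <- r_softmax_normal(5000, mix_loc, temperature = 5)
z_cold <- r_softmax_normal(5000, mix_loc, temperature = 0.05)

colMeans(z_warm)              # first component gets the most weight on average
mean(apply(z_warm, 1, max))   # weights stay spread out
mean(apply(z_cold, 1, max))   # one component usually dominates
```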

The default quadrature scheme chooses z_{N, n} as N midpoints of the quantiles of p(z) (generalized quantiles if K > 2). See Dillon and Langmore (2018) for more details.
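For K = 2 the idea of "midpoints of quantiles" can be sketched as follows. This is a simplified reimplementation written from the description above, not TFP's code: the function name, the choice to drop the probability levels 0 and 1 (to avoid infinite quantiles), and the equal weights 1/N (each interval between consecutive edges carries the same probability mass) are all assumptions of this sketch.

```r
softmax_centered <- function(x) {
  v <- c(x, 0)
  e <- exp(v - max(v))
  e / sum(e)
}

# Hypothetical sketch for K = 2 (scalar X); not TFP's implementation.
softmaxnormal_quantile_grid <- function(normal_loc, normal_scale, N) {
  # N + 1 interior probability levels; 0 and 1 are dropped to avoid +/-Inf.
  p_edges <- seq(0, 1, length.out = N + 3)[2:(N + 2)]
  x_edges <- qnorm(p_edges, mean = normal_loc, sd = normal_scale)
  # Map each quantile of X onto the simplex: an (N + 1) x 2 matrix.
  z_edges <- t(sapply(x_edges, softmax_centered))
  # Midpoints of consecutive edges give N grid points on the simplex.
  grid <- (z_edges[-1, , drop = FALSE] + z_edges[-(N + 1), , drop = FALSE]) / 2
  list(grid = grid, probs = rep(1 / N, N))
}

mix_loc <- 0.5
temperature <- 1
q <- softmaxnormal_quantile_grid(mix_loc / temperature, 1 / temperature, N = 8)
round(q$grid, 3)
```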

About Vector distributions in TensorFlow.

The VectorDiffeomixture is a non-standard distribution with properties that are particularly useful in variational Bayesian methods. Conditioned on a draw z from the SoftmaxNormal, X | z is a vector whose components are convex combinations (with weights z) of affine transformations of the same base draws; X | z is therefore itself an affine transformation of those draws.
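The affine property can be checked numerically: mixing K affine transformations of the same base draw u with weights z gives the same vector as a single affine transformation whose loc and scale are the z-weighted combinations. The component values below are illustrative assumptions, and a Uniform base distribution is used only to emphasize that the identity does not depend on the base law.

```r
set.seed(7)
d <- 3
# Illustrative K = 2 components (assumed values); the second scale matrix is
# symmetric and diagonally dominant, hence positive-definite.
locs   <- list(c(1, 0, -1), c(-2, 0.5, 3))
scales <- list(diag(d),
               matrix(c(2, 0.5, 0, 0.5, 2, 0.5, 0, 0.5, 2), nrow = d))
z <- c(0.3, 0.7)   # a point on the simplex
u <- runif(d)      # d iid draws from a scalar base distribution

# Mix the K affine transformations of the same base draw u ...
x_mixed <- Reduce(`+`, Map(function(zk, l, s) zk * (l + s %*% u), z, locs, scales))
# ... which equals one affine transformation with interpolated loc and scale.
loc_z    <- Reduce(`+`, Map(`*`, z, locs))
scale_z  <- Reduce(`+`, Map(`*`, z, scales))
x_affine <- loc_z + scale_z %*% u

all.equal(as.vector(x_mixed), as.vector(x_affine))  # TRUE
```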

Note: The marginals X_1 | z, ..., X_d | z are not, in general, identical to any parameterization of distribution. This is because a sum of draws from distribution is not, in general, itself distributed according to distribution.
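A quick simulation shows the effect with a Uniform(0, 1) base distribution and an assumed scale matrix with one nonzero off-diagonal entry: the first coordinate becomes U_1 + U_2, which is triangular on [0, 2]. The only uniform law on the same support, Uniform(0, 2), would give P(X_1 < 0.5) = 0.25, while the triangular law gives 0.125.

```r
set.seed(11)
n <- 1e5
# d = 2 iid draws per sample from a Uniform(0, 1) base distribution.
u <- matrix(runif(2 * n), ncol = 2)
# Illustrative scale matrix (assumed): rows (1, 1) and (0, 1).
scale <- matrix(c(1, 0, 1, 1), nrow = 2)
x <- u %*% t(scale)   # row i is (scale %*% u_i)^T, so X_1 = U_1 + U_2

# Empirical P(X_1 < 0.5); the triangular law gives 0.125.
mean(x[, 1] < 0.5)
```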

About Diffeomixtures and reparameterization.

The VectorDiffeomixture is designed to be reparameterized, i.e., its parameters are only used to transform samples from a distribution which has no trainable parameters. This property is important because backprop stops at sources of stochasticity. That is, as long as the parameters are used after the underlying source of stochasticity, the computed gradient is accurate. Reparameterization means that we can use gradient descent (via backprop) to optimize Monte Carlo objectives. Such objectives are finite-sample approximations of an expectation and arise throughout scientific computing.
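The reparameterization idea can be illustrated without TensorFlow on a scalar toy problem (an assumption of this sketch, not the VDM itself): draw parameter-free noise u once, form x = mu + sigma * u, and differentiate the Monte Carlo objective mean(x^2) with respect to mu. Because mu enters only after the noise is drawn, the chain-rule (pathwise) gradient agrees with a finite-difference derivative of the same objective and estimates the exact gradient 2 * mu of E[(mu + sigma * U)^2] = mu^2 + sigma^2.

```r
set.seed(3)
u     <- rnorm(1e5)   # parameter-free noise, drawn once and held fixed
sigma <- 1.5
J_hat <- function(mu) mean((mu + sigma * u)^2)   # Monte Carlo objective

mu <- 0.8
pathwise    <- mean(2 * (mu + sigma * u))        # d J_hat / d mu by the chain rule
h           <- 1e-4
finite_diff <- (J_hat(mu + h) - J_hat(mu - h)) / (2 * h)
c(pathwise = pathwise, finite_diff = finite_diff, exact = 2 * mu)
```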

WARNING: If you backprop through a VectorDiffeomixture sample and the "base" distribution is both not FULLY_REPARAMETERIZED and a function of trainable variables, then the gradient is not guaranteed to be correct!

References

See also

For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().

Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), 
tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()