The Dirichlet-Multinomial distribution is parameterized by a (batch of)
length-K concentration vectors (K > 1) and a total_count number of
trials, i.e., the number of trials per draw from the DirichletMultinomial. It
is defined over a (batch of) length-K vector counts such that
tf$reduce_sum(counts, -1) = total_count. The Dirichlet-Multinomial is
identical to the Beta-Binomial distribution when K = 2.
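The K = 2 equivalence can be checked numerically. The sketch below uses plain base R rather than the tfprobability API, so it runs without TensorFlow. The helpers lmbeta(), dm_pmf() and bb_pmf() are hypothetical names introduced only for this illustration. dm_pmf() follows the pmf given under Mathematical Details, and bb_pmf() is the standard Beta-Binomial pmf choose(N, k) * B(k + a, N - k + b) / B(a, b).

```r
# Log of the multivariate beta function: sum_j lgamma(x_j) - lgamma(sum_j x_j).
lmbeta <- function(x) sum(lgamma(x)) - lgamma(sum(x))

# Hypothetical helper: Dirichlet-Multinomial pmf for one count vector n.
dm_pmf <- function(n, alpha) {
  N <- sum(n)
  exp(lgamma(N + 1) - sum(lgamma(n + 1)) + lmbeta(alpha + n) - lmbeta(alpha))
}

# Hypothetical helper: Beta-Binomial pmf for k successes in N trials.
bb_pmf <- function(k, N, a, b) {
  choose(N, k) * beta(k + a, N - k + b) / beta(a, b)
}

N <- 6
a <- 1.5
b <- 0.7

# With K = 2, the count vector is (k, N - k) and the concentration is (a, b).
dm <- sapply(0:N, function(k) dm_pmf(c(k, N - k), c(a, b)))
bb <- sapply(0:N, function(k) bb_pmf(k, N, a, b))
print(all.equal(dm, bb))
```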
Usage
tfd_dirichlet_multinomial(
  total_count,
  concentration,
  validate_args = FALSE,
  allow_nan_stats = TRUE,
  name = "DirichletMultinomial"
)
Arguments
| Argument | Description |
|---|---|
| total_count | Non-negative floating point tensor, whose dtype is the same as concentration. |
| concentration | Positive floating point tensor, whose dtype is the same as total_count. |
| validate_args | Logical, default FALSE. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs. |
| allow_nan_stats | Logical, default TRUE. When TRUE, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When FALSE, an exception is raised if one or more of the statistic's batch members are undefined. |
| name | Name prefixed to Ops created by this class. |
Value
A distribution instance.
Mathematical Details
The Dirichlet-Multinomial is a distribution over K-class counts, i.e., a
length-K vector of non-negative integer counts, n = [n_0, ..., n_{K-1}].
The probability mass function (pmf) is,
pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
Z = Beta(alpha) / N!
where:
concentration = alpha = [alpha_0, ..., alpha_{K-1}], alpha_j > 0,
total_count = N, N a positive integer,
N! is N factorial,
Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j) is the multivariate beta function, and
Gamma is the gamma function.
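A minimal base-R sketch of the pmf above, evaluated in log space with lgamma() to avoid overflow in the factorials and gamma functions. It illustrates the formula only and is not the tfprobability implementation; dm_log_pmf() and lmbeta() are hypothetical helpers. Summing the pmf over every count vector with K = 3 classes and N = 4 trials should give 1 up to floating-point rounding, a simple check that Z normalizes the distribution.

```r
# Log of the multivariate beta function Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j).
lmbeta <- function(x) sum(lgamma(x)) - lgamma(sum(x))

# Hypothetical helper:
# log pmf(n; alpha, N) = log Beta(alpha + n) - sum_j log(n_j!) - log Z,
# with log Z = log Beta(alpha) - log(N!), and N = sum(n).
dm_log_pmf <- function(n, alpha) {
  stopifnot(length(n) == length(alpha), all(n >= 0), all(alpha > 0))
  N <- sum(n)
  log_z <- lmbeta(alpha) - lgamma(N + 1)
  lmbeta(alpha + n) - sum(lgamma(n + 1)) - log_z
}

# Enumerate every count vector with K = 3 classes that sums to N = 4.
alpha <- c(0.5, 2, 3)
N <- 4
grid <- expand.grid(n0 = 0:N, n1 = 0:N)
grid <- grid[grid$n0 + grid$n1 <= N, ]
counts <- cbind(grid$n0, grid$n1, N - grid$n0 - grid$n1)

# The probabilities over the whole support should sum to 1.
probs <- apply(counts, 1, function(n) exp(dm_log_pmf(n, alpha)))
total <- sum(probs)
print(total)
```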
The Dirichlet-Multinomial is a compound distribution; its samples are generated as follows:
Choose class probabilities:
probs = [p_0,...,p_{K-1}] ~ Dir(concentration)
Draw integers:
counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)
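The two-step recipe can be sketched in base R. This assumes the standard construction of a Dirichlet draw by normalizing independent Gamma(alpha_j, 1) variates, followed by a multinomial draw with base R's rmultinom(). The helpers rdirichlet1() and rdirmult() are hypothetical and not part of tfprobability, which draws samples through tfd_sample().

```r
set.seed(42)

# Hypothetical helper: one Dirichlet(alpha) draw from normalized Gamma(alpha_j, 1) variates.
rdirichlet1 <- function(alpha) {
  g <- rgamma(length(alpha), shape = alpha, rate = 1)
  g / sum(g)
}

# Hypothetical helper: compound draws, probs ~ Dir(alpha), then
# counts ~ Multinomial(total_count, probs). Returns one draw per row.
rdirmult <- function(n_draws, total_count, alpha) {
  t(replicate(n_draws, {
    p <- rdirichlet1(alpha)
    as.vector(rmultinom(1, size = total_count, prob = p))
  }))
}

alpha <- c(1, 2, 3)
total_count <- 10
draws <- rdirmult(5000, total_count, alpha)

print(head(draws))
print(colMeans(draws))
print(total_count * alpha / sum(alpha))
```

Each row sums to total_count, and the column means should approach total_count * alpha / sum(alpha), the mean of the Dirichlet-Multinomial.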
The last concentration dimension parametrizes a single Dirichlet-Multinomial
distribution. When calling distribution functions (e.g., dist$prob(counts)),
concentration, total_count and counts are broadcast to the same shape.
The last dimension of counts corresponds to a single Dirichlet-Multinomial distribution.
Distribution parameters are automatically broadcast in all functions; see examples for details.
Pitfalls
The number of classes, K, must not exceed either of the following:
the largest integer exactly representable by the distribution's dtype, i.e.,
2**(mantissa_bits+1) (IEEE 754),
the maximum Tensor index, i.e., 2**31-1.
Note: This condition is validated only when validate_args = TRUE.
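The arithmetic behind these two limits can be inspected from base R. This sketch assumes float64 (R's double) and float32 follow IEEE 754 binary64 and binary32, with 52 and 23 stored mantissa bits respectively; it does not call tfprobability and only illustrates where the bound on K comes from.

```r
# R doubles are IEEE 754 binary64. .Machine$double.digits is 53: the 52 stored
# mantissa bits plus the implicit leading bit.
mantissa_bits_f64 <- .Machine$double.digits - 1
max_exact_f64 <- 2^(mantissa_bits_f64 + 1)  # 2^53

# Integers up to 2^53 are distinct doubles, but 2^53 + 1 rounds back to 2^53.
print(max_exact_f64 - 1 != max_exact_f64)
print(max_exact_f64 + 1 == max_exact_f64)

# binary32 (float32) stores 23 mantissa bits, so its limit is 2^24.
max_exact_f32 <- 2^(23 + 1)

# Maximum Tensor index, which equals R's largest 32-bit signed integer.
max_index <- 2^31 - 1
print(.Machine$integer.max == max_index)

# The bound on K is whichever limit is smaller for the chosen dtype.
k_limit <- c(float32 = min(max_exact_f32, max_index),
             float64 = min(max_exact_f64, max_index))
print(k_limit)
```

Under these assumptions, the dtype limit 2^24 is the binding one for float32, while for float64 the Tensor index limit 2^31 - 1 is smaller.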
For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions:
tfd_autoregressive(),
tfd_batch_reshape(),
tfd_bates(),
tfd_bernoulli(),
tfd_beta_binomial(),
tfd_beta(),
tfd_binomial(),
tfd_categorical(),
tfd_cauchy(),
tfd_chi2(),
tfd_chi(),
tfd_cholesky_lkj(),
tfd_continuous_bernoulli(),
tfd_deterministic(),
tfd_dirichlet(),
tfd_empirical(),
tfd_exp_gamma(),
tfd_exp_inverse_gamma(),
tfd_exponential(),
tfd_gamma_gamma(),
tfd_gamma(),
tfd_gaussian_process_regression_model(),
tfd_gaussian_process(),
tfd_generalized_normal(),
tfd_geometric(),
tfd_gumbel(),
tfd_half_cauchy(),
tfd_half_normal(),
tfd_hidden_markov_model(),
tfd_horseshoe(),
tfd_independent(),
tfd_inverse_gamma(),
tfd_inverse_gaussian(),
tfd_johnson_s_u(),
tfd_joint_distribution_named_auto_batched(),
tfd_joint_distribution_named(),
tfd_joint_distribution_sequential_auto_batched(),
tfd_joint_distribution_sequential(),
tfd_kumaraswamy(),
tfd_laplace(),
tfd_linear_gaussian_state_space_model(),
tfd_lkj(),
tfd_log_logistic(),
tfd_log_normal(),
tfd_logistic(),
tfd_mixture_same_family(),
tfd_mixture(),
tfd_multinomial(),
tfd_multivariate_normal_diag_plus_low_rank(),
tfd_multivariate_normal_diag(),
tfd_multivariate_normal_full_covariance(),
tfd_multivariate_normal_linear_operator(),
tfd_multivariate_normal_tri_l(),
tfd_multivariate_student_t_linear_operator(),
tfd_negative_binomial(),
tfd_normal(),
tfd_one_hot_categorical(),
tfd_pareto(),
tfd_pixel_cnn(),
tfd_poisson_log_normal_quadrature_compound(),
tfd_poisson(),
tfd_power_spherical(),
tfd_probit_bernoulli(),
tfd_quantized(),
tfd_relaxed_bernoulli(),
tfd_relaxed_one_hot_categorical(),
tfd_sample_distribution(),
tfd_sinh_arcsinh(),
tfd_skellam(),
tfd_spherical_uniform(),
tfd_student_t_process(),
tfd_student_t(),
tfd_transformed_distribution(),
tfd_triangular(),
tfd_truncated_cauchy(),
tfd_truncated_normal(),
tfd_uniform(),
tfd_variational_gaussian_process(),
tfd_vector_diffeomixture(),
tfd_vector_exponential_diag(),
tfd_vector_exponential_linear_operator(),
tfd_vector_laplace_diag(),
tfd_vector_laplace_linear_operator(),
tfd_vector_sinh_arcsinh_diag(),
tfd_von_mises_fisher(),
tfd_von_mises(),
tfd_weibull(),
tfd_wishart_linear_operator(),
tfd_wishart_tri_l(),
tfd_wishart(),
tfd_zipf()