This Multinomial distribution is parameterized by probs, a length-K probability vector (or a batch of them) with K > 1 such that tf$reduce_sum(probs, -1) = 1, and total_count, the number of trials per draw from the Multinomial. It is defined over length-K count vectors counts (again possibly batched) such that tf$reduce_sum(counts, -1) = total_count. The Multinomial is identical to the Binomial distribution when K = 2.
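The K = 2 equivalence can be checked numerically. The sketch below uses only base R's stats::dmultinom() and stats::dbinom(), not tfprobability, so it illustrates the mathematical identity rather than this function's implementation. The values total_count = 10 and p = 0.3 are arbitrary illustrative choices.

```r
# Compare a K = 2 Multinomial pmf with the Binomial pmf using base R only
# (stats::dmultinom and stats::dbinom; no TensorFlow involved).
total_count <- 10   # illustrative number of trials
p <- 0.3            # illustrative probability of the first class

k <- 0:total_count
# For K = 2, each count vector is c(k, total_count - k), which sums to total_count.
pmf_multinomial_k2 <- vapply(
  k,
  function(n1) dmultinom(c(n1, total_count - n1), prob = c(p, 1 - p)),
  numeric(1)
)
pmf_binomial <- dbinom(k, size = total_count, prob = p)

# The two pmfs should agree up to floating point rounding.
all.equal(pmf_multinomial_k2, pmf_binomial)
```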
tfd_multinomial(total_count, logits = NULL, probs = NULL, validate_args = FALSE, allow_nan_stats = TRUE, name = "Multinomial")
Argument | Description |
---|---|
total_count | Non-negative floating point tensor giving the number of trials per draw, with shape broadcastable to the batch shape of the distribution. |
logits | Floating point tensor of unnormalized log-probabilities for each of the K classes, with shape broadcastable to the batch shape followed by K. |
probs | Positive floating point tensor of class probabilities, with shape broadcastable to the batch shape followed by K; each length-K vector must sum to 1. |
validate_args | Logical, default FALSE. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs. |
allow_nan_stats | Logical, default TRUE. When TRUE, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When FALSE, an exception is raised if one or more of the statistic's batch members are undefined. |
name | Name prefixed to Ops created by this class. |
Returns a distribution instance.
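Because logits are unnormalized, different logits vectors can describe the same probs. Assuming the usual softmax normalization over the last (length-K) axis, the mapping can be sketched in base R as below. The softmax() helper is hypothetical, written for this illustration, and is not part of tfprobability.

```r
# Hypothetical helper: softmax over a length-K vector of logits.
softmax <- function(logits) {
  z <- exp(logits - max(logits))  # subtract the max for numerical stability
  z / sum(z)
}

logits <- c(1.0, 2.0, 0.5)   # illustrative unnormalized log-probabilities, K = 3
probs <- softmax(logits)

sum(probs)                              # 1 up to rounding, as probs requires
all.equal(probs, softmax(logits + 10))  # shifting all logits leaves probs unchanged
```

The shift-invariance check is why logits are described as unnormalized: only differences between the K entries matter.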
Mathematical Details
The Multinomial is a distribution over K-class counts, i.e., a length-K vector of non-negative integer counts, $n = [n_0, \ldots, n_{K-1}]$.
The probability mass function (pmf) is

$$\mathrm{pmf}(n; \pi, N) = \frac{\prod_j \pi_j^{n_j}}{Z}, \qquad Z = \frac{\prod_j n_j!}{N!}$$
where probs = $\pi = [\pi_0, \ldots, \pi_{K-1}]$ with $\pi_j > 0$ and $\sum_j \pi_j = 1$; total_count = $N$, a positive integer; $Z$ is the normalization constant; and $N!$ denotes $N$ factorial.
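The pmf can be written out directly from the formula above. The sketch below is a hypothetical base R helper, multinomial_pmf(), that works on the log scale via lfactorial() so that N! does not overflow for large counts; it is checked against base R's stats::dmultinom(). The counts and probabilities are illustrative.

```r
# Hypothetical helper implementing pmf(n; pi, N) = prod_j pi_j^n_j / Z,
# with Z = (prod_j n_j!) / N!, evaluated on the log scale.
multinomial_pmf <- function(counts, probs) {
  stopifnot(length(counts) == length(probs), all(counts >= 0), all(probs > 0))
  N <- sum(counts)                                    # total_count
  log_Z <- sum(lfactorial(counts)) - lfactorial(N)    # log of the normalization constant
  exp(sum(counts * log(probs)) - log_Z)
}

counts <- c(2, 0, 3)          # K = 3, N = 5
probs  <- c(0.2, 0.3, 0.5)

multinomial_pmf(counts, probs)   # about 0.05, i.e. 5!/(2! 0! 3!) * 0.2^2 * 0.3^0 * 0.5^3
dmultinom(counts, prob = probs)  # base R reference value
```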
Distribution parameters are automatically broadcast in all functions; see examples for details.
Pitfalls
The number of classes, K, must not exceed:
- the largest integer up to which every integer is exactly representable in self$dtype, i.e., 2**(mantissa_bits+1) (IEEE 754), and
- the maximum Tensor index, i.e., 2**31-1.
Note: this condition is validated only when validate_args = TRUE.
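The two bounds can be computed directly. This base R sketch assumes the standard IEEE 754 stored-mantissa widths (23 bits for float32, 52 for float64). The k_is_valid() helper is hypothetical, for illustration only, and does not reproduce the check tfprobability performs when validate_args = TRUE.

```r
# Largest integer up to which all integers are exactly representable,
# 2^(mantissa_bits + 1), for two common IEEE 754 formats (assumed widths).
mantissa_bits <- c(float32 = 23, float64 = 52)
max_exact_int <- 2^(mantissa_bits + 1)   # 2^24 for float32, 2^53 for float64

# R's numeric type is IEEE 754 double: 53 significand bits (52 stored + 1 implicit).
.Machine$double.digits
(2^53 + 1) == 2^53      # TRUE: 2^53 + 1 rounds back to 2^53

# The maximum Tensor index bound quoted above.
max_tensor_index <- 2^31 - 1   # same value as .Machine$integer.max in R

# Hypothetical helper: does K satisfy both bounds for the given dtype?
k_is_valid <- function(K, dtype = c("float32", "float64")) {
  dtype <- match.arg(dtype)
  K <= min(max_exact_int[[dtype]], max_tensor_index)
}

k_is_valid(2^25, "float32")   # FALSE: exceeds 2^24
k_is_valid(2^25, "float64")   # TRUE: the index bound 2^31 - 1 is the binding one
```

For float32 the mantissa bound (2^24) is the tighter limit, while for float64 the Tensor index bound (2^31 - 1) is.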
For usage examples, see e.g. tfd_sample(), tfd_log_prob(), and tfd_mean().
Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()