This Multinomial distribution is parameterized by probs, a (batch of) length-K probability vectors (K > 1) such that tf$reduce_sum(probs, -1) = 1, and by total_count, the number of trials per draw from the Multinomial. It is defined over a (batch of) length-K vectors counts such that tf$reduce_sum(counts, -1) = total_count. When K = 2, the Multinomial is equivalent to the Binomial distribution.
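The two constraints in the description can be checked with a small sketch. It uses base R's rmultinom(), dmultinom() and dbinom() from the stats package as a stand-in, not tfprobability itself, and the values of probs, total_count and p are illustrative assumptions.

```r
# Base R stand-in (stats package); this does not call tfprobability.
set.seed(1)
probs <- c(0.2, 0.3, 0.5)   # length-K probability vector (K = 3), sums to 1
total_count <- 4            # number of trials per draw

# Each column of `draws` is one length-K count vector.
draws <- rmultinom(n = 5, size = total_count, prob = probs)
stopifnot(all(colSums(draws) == total_count))

# K = 2: the count of the first class follows Binomial(total_count, p).
p <- 0.3
k <- 0:total_count
multinomial_pmf <- sapply(
  k,
  function(i) dmultinom(c(i, total_count - i), prob = c(p, 1 - p))
)
binomial_pmf <- dbinom(k, size = total_count, prob = p)
stopifnot(isTRUE(all.equal(multinomial_pmf, binomial_pmf)))
```

Every sampled count vector sums to total_count, and for K = 2 the Multinomial pmf over (k, total_count - k) matches the Binomial pmf over k.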

tfd_multinomial(
  total_count,
  logits = NULL,
  probs = NULL,
  validate_args = FALSE,
  allow_nan_stats = TRUE,
  name = "Multinomial"
)



total_count: Non-negative floating point tensor with shape broadcastable to [N1,..., Nm], with m >= 0. Defines this as a batch of N1 x ... x Nm different Multinomial distributions. Its components should be equal to integer values.


logits: Floating point tensor representing unnormalized log-probabilities of the K classes, with shape broadcastable to [N1,..., Nm, K], with m >= 0, and the same dtype as total_count. Defines this as a batch of N1 x ... x Nm different K class Multinomial distributions. Only one of logits or probs should be passed in.


probs: Positive floating point tensor with shape broadcastable to [N1,..., Nm, K], with m >= 0, and the same dtype as total_count. Defines this as a batch of N1 x ... x Nm different K class Multinomial distributions. The components of probs along the last dimension should sum to 1. Only one of logits or probs should be passed in.
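The relationship between the two parameterizations can be sketched in base R. The sketch assumes the usual convention that unnormalized log-probabilities are mapped to probabilities by a softmax over the last dimension; the softmax() helper and the logits values are hypothetical and defined only for this illustration.

```r
logits <- c(-1.0, 0.5, 2.0)   # hypothetical unnormalized log-probabilities

# Softmax; subtracting max(x) avoids overflow without changing the result.
softmax <- function(x) {
  z <- exp(x - max(x))
  z / sum(z)
}

probs <- softmax(logits)
stopifnot(isTRUE(all.equal(sum(probs), 1)))

# Adding a constant to every logit leaves probs unchanged, which is why
# logits are described as "unnormalized".
stopifnot(isTRUE(all.equal(softmax(logits + 10), probs)))

# log(probs) is itself a valid set of logits for the same distribution.
stopifnot(isTRUE(all.equal(softmax(log(probs)), probs)))
```

Because both arguments describe the same class probabilities, passing only one of them avoids specifying the distribution twice.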


validate_args: Logical, default FALSE. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs.


allow_nan_stats: Logical, default TRUE. When TRUE, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When FALSE, an exception is raised if one or more of the statistic's batch members are undefined.


name: Name prefixed to Ops created by this class.


Returns a distribution instance.


Mathematical Details

The Multinomial is a distribution over K-class counts, i.e., a length-K vector of non-negative integer counts = n = [n_0, ..., n_{K-1}]. The probability mass function (pmf) is,

pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
Z = (prod_j n_j!) / N!

where:
  • probs = pi = [pi_0, ..., pi_{K-1}], pi_j > 0, sum_j pi_j = 1,

  • total_count = N, N a positive integer,

  • Z is the normalization constant, and,

  • N! denotes N factorial.
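The pmf above can be evaluated directly for a small case. The following sketch uses base R only; multinomial_pmf() is a hypothetical helper written for this illustration, and stats::dmultinom() serves as an independent reference.

```r
# pmf(n; pi, N) = prod_j (pi_j)**n_j / Z, with Z = (prod_j n_j!) / N!
multinomial_pmf <- function(n, probs) {
  N <- sum(n)                              # total_count
  Z <- prod(factorial(n)) / factorial(N)   # normalization constant
  prod(probs^n) / Z
}

probs <- c(0.2, 0.3, 0.5)   # pi
n <- c(1, 1, 2)             # counts, N = 4

manual <- multinomial_pmf(n, probs)
reference <- dmultinom(n, prob = probs)
stopifnot(isTRUE(all.equal(manual, reference)))
```

For these values Z = 2 / 24 and prod_j (pi_j)**n_j = 0.015, so the pmf is 0.18.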

Distribution parameters are automatically broadcast in all functions; see examples for details.
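A minimal sketch of broadcasting, assuming a working TensorFlow and tfprobability installation. The parameter values are illustrative, and the shapes in the comments are what the broadcasting rules imply rather than captured output.

```r
library(tfprobability)

# Batch of 2 distributions over K = 3 classes; each row sums to 1.
probs <- matrix(c(0.1, 0.2, 0.7,
                  0.3, 0.3, 0.4), nrow = 2, byrow = TRUE)

# A scalar total_count is broadcast across both batch members.
d <- tfd_multinomial(total_count = 4, probs = probs)

# A single length-3 count vector is likewise evaluated under both members,
# so one log-probability per batch member is expected (shape [2]).
lp <- tfd_log_prob(d, c(2, 1, 1))

# Expected shape [5, 2, 3]: 5 draws, 2 batch members, K = 3 counts each.
draws <- tfd_sample(d, 5L)

# Mean counts per batch member.
m <- tfd_mean(d)
```

Passing total_count as a length-2 vector instead would give each batch member its own number of trials.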


The number of classes, K, must not exceed:

  • the largest integer representable by self$dtype, i.e., 2**(mantissa_bits+1) (IEEE 754),

  • the maximum Tensor index, i.e., 2**31-1.

Note: This condition is validated only when validate_args = TRUE.
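The two bounds can be made concrete in base R. The sketch assumes IEEE 754 single and double precision, with 23 and 52 explicit mantissa bits respectively, and reads "largest integer representable" as the point up to which every integer is still represented exactly.

```r
# float32: 23 explicit mantissa bits (IEEE 754 single precision).
float32_limit <- 2^(23 + 1)   # 16777216

# float64: 52 explicit mantissa bits; R doubles are float64.
float64_limit <- 2^(52 + 1)

# Above the limit, consecutive integers are no longer distinguishable.
stopifnot(float64_limit + 1 == float64_limit)
stopifnot(float64_limit - 1 != float64_limit)

# The maximum Tensor index, 2**31 - 1, equals R's largest integer value.
stopifnot(.Machine$integer.max == 2^31 - 1)
```

With a float32 dtype the first bound (about 1.7e7) is the tighter one; with float64 the Tensor index bound applies first.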

See also

For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().

Other distributions: tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), 
tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()