This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo) to sample from a sequence of distributions that slowly interpolate between an initial "proposal" distribution, exp(proposal_log_prob_fn(x) - proposal_log_normalizer), and the target distribution, exp(target_log_prob_fn(x) - target_log_normalizer), accumulating importance weights along the way. The product of these importance weights (equivalently, the exponential of the accumulated log-weights ais_weights) gives an unbiased estimate of the ratio of the target's normalizing constant to the proposal's: E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer).
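To make the mechanics concrete, here is a minimal base-R sketch of annealed importance sampling on a one-dimensional toy problem. It does not call tfprobability. The Gaussian proposal and target, the random-walk Metropolis step (standing in for HMC), and the linear schedule beta = t / num_steps are all assumptions, chosen so that the true log ratio of normalizers is known in closed form.

```r
set.seed(1)

# Proposal: normalized N(0, 1), so proposal_log_normalizer = 0.
proposal_log_prob <- function(x) dnorm(x, mean = 0, sd = 1, log = TRUE)

# Target: unnormalized N(2, 0.5^2); its log normalizer is log(0.5 * sqrt(2 * pi)).
target_log_prob <- function(x) -(x - 2)^2 / (2 * 0.5^2)

run_ais <- function(num_steps, num_chains, step_sd = 0.5) {
  x <- rnorm(num_chains)          # exact draws from the proposal
  log_w <- numeric(num_chains)    # accumulated log importance weights, one per chain
  for (t in seq_len(num_steps)) {
    # Weight increment log f_t(x) - log f_{t-1}(x) under the assumed linear schedule.
    log_w <- log_w + (target_log_prob(x) - proposal_log_prob(x)) / num_steps
    beta <- t / num_steps
    # Interpolated (unnormalized) log-density f_t.
    log_f <- function(z) (1 - beta) * proposal_log_prob(z) + beta * target_log_prob(z)
    # One random-walk Metropolis step that leaves f_t invariant.
    candidate <- x + rnorm(num_chains, sd = step_sd)
    accept <- log(runif(num_chains)) < log_f(candidate) - log_f(x)
    x[accept] <- candidate[accept]
  }
  list(next_state = x, ais_weights = log_w)
}

res <- run_ais(num_steps = 500, num_chains = 2000)

# log(mean(exp(ais_weights))), computed stably by shifting by the maximum.
m <- max(res$ais_weights)
estimate <- m + log(mean(exp(res$ais_weights - m)))
truth <- log(0.5 * sqrt(2 * pi))  # target_log_normalizer - proposal_log_normalizer
```

Swapping the Metropolis step for another valid kernel changes how quickly the chains track each intermediate distribution, and hence the variance of the estimate, but not what is being estimated.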

mcmc_sample_annealed_importance_chain(
  num_steps,
  proposal_log_prob_fn,
  target_log_prob_fn,
  current_state,
  make_kernel_fn,
  parallel_iterations = 10,
  name = NULL
)
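A minimal usage sketch, assuming the tensorflow and tfprobability R packages are installed and TensorFlow runs eagerly. The proposal is a standard normal; the target is an unnormalized Gaussian chosen only because its log normalizer, (dims / 2) * log(2 * pi), is known in closed form; the HMC step size and leapfrog count are illustrative, untuned values. The results are read by position, in the order listed under Value.

```r
library(tensorflow)
library(tfprobability)

num_chains <- 100L
dims <- 2L

# Proposal: standard normal in `dims` dimensions; it is normalized, so proposal_log_normalizer = 0.
proposal <- tfd_multivariate_normal_diag(loc = tf$zeros(shape(dims)))

# Target: unnormalized Gaussian centred at 1, exp(-0.5 * ||x - 1||^2).
# Its log normalizer is (dims / 2) * log(2 * pi).
target_log_prob_fn <- function(x) -0.5 * tf$reduce_sum(tf$square(x - 1), axis = -1L)

res <- mcmc_sample_annealed_importance_chain(
  num_steps = 200L,
  proposal_log_prob_fn = function(x) tfd_log_prob(proposal, x),
  target_log_prob_fn = target_log_prob_fn,
  # Shape [num_chains, dims]: the first dimension indexes independent chains.
  current_state = tfd_sample(proposal, num_chains),
  # Receives the interpolated log-density and builds a kernel that targets it.
  make_kernel_fn = function(interpolated_log_prob_fn)
    mcmc_hamiltonian_monte_carlo(
      target_log_prob_fn = interpolated_log_prob_fn,
      step_size = 0.5,
      num_leapfrog_steps = 3L
    )
)

# Elements in the order listed under Value: next_state, ais_weights, kernel_results.
ais_weights <- res[[2]]

# log(mean(exp(ais_weights))) estimates target_log_normalizer - proposal_log_normalizer.
estimate <- as.array(tf$reduce_logsumexp(ais_weights) - log(num_chains))
truth <- (dims / 2) * log(2 * pi)
```

Averaging on the log scale with tf$reduce_logsumexp avoids the underflow that exponentiating large negative log-weights directly would cause.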

Arguments

num_steps

Integer number of Markov chain updates to run. More iterations mean more computation, but smoother annealing between the proposal (q) and the target (p), which in turn means exponentially lower variance for the normalizing-constant estimator.

proposal_log_prob_fn

function which takes an argument like current_state and returns its (possibly unnormalized) log-density under the initial (proposal) distribution.

target_log_prob_fn

function which takes an argument like current_state and returns its (possibly unnormalized) log-density under the target distribution.

current_state

Tensor or list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, where r = tf$rank(target_log_prob_fn(current_state)).

make_kernel_fn

function which returns a TransitionKernel-like object. It must take exactly one argument: the target_log_prob_fn that the returned TransitionKernel should target. Note: mcmc_sample_annealed_importance_chain creates a new target_log_prob_fn that interpolates between the supplied proposal_log_prob_fn and target_log_prob_fn; it is this interpolated function, not the supplied target_log_prob_fn, that is passed to make_kernel_fn.
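As a worked equation, assuming a linear annealing schedule beta_t = t / T with T = num_steps (the schedule is an assumption here), the interpolated log-density handed to make_kernel_fn at step t, and the log-weight that accumulates into ais_weights, can be written as follows. Here log q̃ is proposal_log_prob_fn, log p̃ is target_log_prob_fn, x_0 is current_state (assumed drawn from the proposal), and x_t is produced by one kernel step targeting f_t.

```latex
\log f_t(x) = (1 - \beta_t)\,\log \tilde{q}(x) + \beta_t\,\log \tilde{p}(x),
\qquad \beta_t = \frac{t}{T}, \quad t = 0, 1, \dots, T

\text{ais\_weights}
  = \sum_{t=1}^{T} \bigl[\log f_t(x_{t-1}) - \log f_{t-1}(x_{t-1})\bigr]
  = \frac{1}{T} \sum_{t=1}^{T} \bigl[\log \tilde{p}(x_{t-1}) - \log \tilde{q}(x_{t-1})\bigr]
```

Each increment is (beta_t - beta_{t-1}) times the log-ratio of target to proposal, which is why more steps mean smaller, smoother increments.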

parallel_iterations

The number of iterations allowed to run in parallel. It must be a positive integer. See tf$while_loop for more details.

name

string prefixed to Ops created by this function. Default value: NULL (i.e., "sample_annealed_importance_chain").

Value

A list of three elements: next_state (Tensor or list of Tensors representing the state(s) of the Markov chain(s) at the final iteration; it has the same shape as the input current_state), ais_weights (Tensor of estimated log importance weight(s); its shape matches that of target_log_prob_fn(current_state)), and kernel_results (collections.namedtuple of internal calculations used to advance the chain).
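Because ais_weights are log-weights, combining several chains into one estimate of target_log_normalizer - proposal_log_normalizer means computing log(mean(exp(ais_weights))), and exponentiating large negative log-weights underflows to zero. Below is a small base-R sketch of the usual max-shift (log-mean-exp) trick; the weight values are made up for illustration. By Jensen's inequality, the log of this unbiased average underestimates the log ratio on average, most noticeably when few chains are used.

```r
# Numerically stable log(mean(exp(log_w))) for a vector of log importance weights.
log_mean_exp <- function(log_w) {
  m <- max(log_w)                 # shift by the largest log-weight so exp() cannot underflow to all zeros
  m + log(mean(exp(log_w - m)))
}

log_w <- c(-1000, -1001, -1002)   # hypothetical ais_weights from three chains
naive <- log(mean(exp(log_w)))    # exp(-1000) underflows to 0, so this is -Inf
stable <- log_mean_exp(log_w)     # about -1000.691
```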

Details

Note: When running in graph mode, proposal_log_prob_fn and target_log_prob_fn are called exactly three times (although this may be reduced to two times in the future).

See also