Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm
that takes a series of gradient-informed steps to produce a Metropolis
proposal. The kernel returned by this function implements one random HMC step
from a given current_state. Mathematical details and derivations can be found in
Neal (2011).
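The base-R sketch below shows what a single HMC step does: resample a Gaussian momentum, run the leapfrog integrator, then accept or reject the endpoint with a Metropolis test on the joint energy. It is not the tfprobability implementation. It assumes a standard normal target, an identity mass matrix, and a hand-written gradient (`grad_target_log_prob`, a hypothetical helper) in place of automatic differentiation.

```r
# Sketch of one HMC step in base R (NOT the tfprobability implementation).
# Assumption: standard normal target, log p(q) = -sum(q^2) / 2, identity mass matrix.
target_log_prob <- function(q) -0.5 * sum(q^2)
grad_target_log_prob <- function(q) -q  # hypothetical hand-coded gradient

hmc_step <- function(q, step_size, num_leapfrog_steps) {
  p <- rnorm(length(q))  # resample momentum ~ N(0, I)
  q_new <- q
  # Leapfrog: half momentum step, alternating full steps, final half momentum step
  p_new <- p + 0.5 * step_size * grad_target_log_prob(q_new)
  for (i in seq_len(num_leapfrog_steps)) {
    q_new <- q_new + step_size * p_new
    if (i < num_leapfrog_steps) {
      p_new <- p_new + step_size * grad_target_log_prob(q_new)
    }
  }
  p_new <- p_new + 0.5 * step_size * grad_target_log_prob(q_new)
  # Metropolis correction on H(q, p) = -log p(q) + sum(p^2) / 2
  log_accept <- (target_log_prob(q_new) - 0.5 * sum(p_new^2)) -
                (target_log_prob(q) - 0.5 * sum(p^2))
  if (log(runif(1)) < log_accept) q_new else q
}

set.seed(1)
q <- c(1, -1)
samples <- matrix(NA_real_, nrow = 2000, ncol = 2)
for (t in 1:2000) {
  q <- hmc_step(q, step_size = 0.25, num_leapfrog_steps = 6)
  samples[t, ] <- q
}
colMeans(samples)        # should be close to 0
apply(samples, 2, sd)    # should be close to 1
```

With these settings each trajectory covers about 1.5 time units, so successive draws are nearly uncorrelated for this target.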
mcmc_hamiltonian_monte_carlo( target_log_prob_fn, step_size, num_leapfrog_steps, state_gradients_are_stopped = FALSE, step_size_update_fn = NULL, seed = NULL, store_parameters_in_results = FALSE, name = NULL )
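| Argument | Description |
|---|---|
| target_log_prob_fn | Function which takes an argument like current_state and returns its (possibly unnormalized) log-density under the target distribution. |
| step_size | Tensor or list of Tensors representing the step size for the leapfrog integrator. Larger step sizes make faster progress, but overly large ones make rejection more likely. |
| num_leapfrog_steps | Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps. |
| state_gradients_are_stopped | Logical indicating that the proposed new state be run through tf$stop_gradient. Default value: FALSE. |
| step_size_update_fn | Function taking the current step_size and the kernel results and returning an updated step_size. Default value: NULL (the step size is not updated automatically). |
| seed | Integer to seed the random number generator. Default value: NULL. |
| store_parameters_in_results | If TRUE, step_size and num_leapfrog_steps are written to and read from the kernel results instead of being stored in the kernel object itself. Default value: FALSE. |
| name | String prefixed to Ops created by this function. Default value: NULL. |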
Returns a Monte Carlo sampling kernel.
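The note that total progress scales with step_size * num_leapfrog_steps follows from the leapfrog integrator simulating Hamiltonian dynamics for a time span of roughly that product. A minimal base-R sketch, assuming a one-dimensional standard normal target, for which the exact flow started at (q, p) = (1, 0) gives q(t) = cos(t); `leapfrog` and `grad` are hypothetical helpers, not tfprobability functions:

```r
# Leapfrog integrator sketch (not the tfprobability code): half momentum step,
# alternating full position/momentum steps, final half momentum step.
leapfrog <- function(q, p, grad, step_size, num_leapfrog_steps) {
  p <- p + 0.5 * step_size * grad(q)
  for (i in seq_len(num_leapfrog_steps)) {
    q <- q + step_size * p
    if (i < num_leapfrog_steps) p <- p + step_size * grad(q)
  }
  p <- p + 0.5 * step_size * grad(q)
  list(q = q, p = p)
}

# Gradient of log N(q | 0, 1); the standard normal target is an assumption.
grad <- function(q) -q

# Same product step_size * num_leapfrog_steps = 1: both end near cos(1).
a <- leapfrog(q = 1, p = 0, grad, step_size = 0.1, num_leapfrog_steps = 10)
b <- leapfrog(q = 1, p = 0, grad, step_size = 0.05, num_leapfrog_steps = 20)
# Doubling the number of steps doubles the simulated time: ends near cos(2).
d <- leapfrog(q = 1, p = 0, grad, step_size = 0.1, num_leapfrog_steps = 20)

c(a$q, b$q, cos(1))
c(d$q, cos(2))
```

Both `a$q` and `b$q` agree with cos(1) to within about 1e-3; the smaller step size only reduces discretization error, while the simulated time is set by the product.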
The one_step function can update multiple chains in parallel. It assumes
that all leftmost dimensions of current_state index independent chain states
(and are therefore updated independently). The output of
target_log_prob_fn(current_state) should sum log-probabilities across all
event dimensions. Slices along the leftmost dimensions may have different
target distributions; for example, current_state[0, :] could have a
different target distribution from current_state[1, :]. These semantics are
governed by target_log_prob_fn(current_state). (The number of independent
chains is tf$size(target_log_prob_fn(current_state)).)
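These batch semantics can be mimicked in base R by storing chain states as rows of a matrix (the leftmost dimension) and having the log-density function sum over columns (the event dimensions), so it returns one value per chain. This is a sketch under assumptions, not how tfprobability vectorizes internally: three hypothetical chains, each targeting its own 2-D normal N(mu[k, ], I), with hand-coded gradients and an independent accept/reject decision per chain.

```r
# Hypothetical setup: 3 chains (rows) with 2-dimensional events (columns).
# Each chain has a different target N(mu[k, ], I), as the text allows.
mu <- rbind(c(0, 0), c(5, 5), c(-3, 2))
target_log_prob <- function(state) -0.5 * rowSums((state - mu)^2)  # one value per chain
grad_target_log_prob <- function(state) -(state - mu)

hmc_step_batched <- function(state, step_size, num_leapfrog_steps) {
  p <- matrix(rnorm(length(state)), nrow = nrow(state))  # fresh momentum per chain
  q_new <- state
  p_new <- p + 0.5 * step_size * grad_target_log_prob(q_new)
  for (i in seq_len(num_leapfrog_steps)) {
    q_new <- q_new + step_size * p_new
    if (i < num_leapfrog_steps) p_new <- p_new + step_size * grad_target_log_prob(q_new)
  }
  p_new <- p_new + 0.5 * step_size * grad_target_log_prob(q_new)
  log_accept <- (target_log_prob(q_new) - 0.5 * rowSums(p_new^2)) -
                (target_log_prob(state) - 0.5 * rowSums(p^2))
  accept <- log(runif(nrow(state))) < log_accept  # independent decision per chain
  state[accept, ] <- q_new[accept, , drop = FALSE]
  state
}

set.seed(2)
state <- matrix(0, nrow = 3, ncol = 2)
n_chains <- length(target_log_prob(state))  # analogue of tf$size(...): 3 chains
draws <- array(NA_real_, dim = c(3000, 3, 2))
for (t in 1:3000) {
  state <- hmc_step_batched(state, step_size = 0.25, num_leapfrog_steps = 6)
  draws[t, , ] <- state
}
est <- apply(draws[1001:3000, , ], c(2, 3), mean)  # per-chain means after burn-in
est
```

After burn-in, each row of `est` should be close to the corresponding row of `mu`, even though all chains were advanced by the same vectorized step.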
Other mcmc_kernels:
mcmc_dual_averaging_step_size_adaptation(),
mcmc_metropolis_adjusted_langevin_algorithm(),
mcmc_metropolis_hastings(),
mcmc_no_u_turn_sampler(),
mcmc_random_walk_metropolis(),
mcmc_replica_exchange_mc(),
mcmc_simple_step_size_adaptation(),
mcmc_slice_sampler(),
mcmc_transformed_transition_kernel(),
mcmc_uncalibrated_hamiltonian_monte_carlo(),
mcmc_uncalibrated_langevin(),
mcmc_uncalibrated_random_walk()