The No U-Turn Sampler (NUTS) is an adaptive variant of the Hamiltonian Monte Carlo (HMC) method for MCMC. NUTS adapts the distance traveled in response to the curvature of the target density. Conceptually, one proposal consists of reversibly evolving a trajectory through the sample space, continuing until that trajectory turns back on itself (hence the name, 'No U-Turn'). The kernel returned by this function implements one random NUTS step from a given current_state. Mathematical details and derivations can be found in Hoffman & Gelman (2011).
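The stopping rule behind the name can be illustrated in base R without tfprobability. This is a minimal sketch assuming a standard bivariate normal target and an identity mass matrix; the helpers `leapfrog()` and `is_u_turn()` are hypothetical names. Unlike real NUTS, which doubles the trajectory in randomly chosen directions and also checks sub-trees, it only integrates forward from the start and stops at the first U-turn.

```r
# Unnormalized target: log p(theta) = -0.5 * sum(theta^2), so the gradient is -theta.
grad_log_prob <- function(theta) -theta

# One leapfrog step of size eps (identity mass matrix assumed).
leapfrog <- function(theta, r, eps) {
  r <- r + 0.5 * eps * grad_log_prob(theta)  # half step in momentum
  theta <- theta + eps * r                   # full step in position
  r <- r + 0.5 * eps * grad_log_prob(theta)  # half step in momentum
  list(theta = theta, r = r)
}

# U-turn criterion (Hoffman & Gelman): stop once the momentum at either end
# points back toward the other end of the trajectory.
is_u_turn <- function(theta_minus, theta_plus, r_minus, r_plus) {
  d <- theta_plus - theta_minus
  sum(d * r_minus) < 0 || sum(d * r_plus) < 0
}

# Forward-only trajectory from a fixed start, stopped at the first U-turn.
theta0 <- c(1, 0)
r0 <- c(0, 1)
state <- list(theta = theta0, r = r0)
n_steps <- 0
repeat {
  state <- leapfrog(state$theta, state$r, eps = 0.1)
  n_steps <- n_steps + 1
  if (is_u_turn(theta0, state$theta, r0, state$r)) break
}
n_steps
```

For this target the exact dynamics are a rotation, so the trajectory turns back after roughly half a revolution (a path length near pi, i.e. about 32 steps of size 0.1).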

mcmc_no_u_turn_sampler(
  target_log_prob_fn,
  step_size,
  max_tree_depth = 10,
  max_energy_diff = 1000,
  unrolled_leapfrog_steps = 1,
  seed = NULL,
  name = NULL
)

Arguments

target_log_prob_fn

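Function that takes an argument like current_state and returns its (possibly unnormalized) log-density under the target distribution. When the state has several parts, they are passed as separate arguments (as with log_prob in the Examples).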
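Why an unnormalized log-density is enough: acceptance decisions depend only on differences of log-densities (and NUTS additionally uses gradients), so an additive constant cancels. A small base-R check, assuming a plain Metropolis-Hastings style log acceptance ratio; `log_accept_ratio()` is a hypothetical helper, not part of tfprobability.

```r
# The same target written with and without its normalizing constant.
log_prob_normalized <- function(x) sum(dnorm(x, log = TRUE))
log_prob_unnormalized <- function(x) -0.5 * sum(x^2)

# Hypothetical helper: log acceptance ratio for a symmetric proposal.
log_accept_ratio <- function(log_prob_fn, current, proposed) {
  log_prob_fn(proposed) - log_prob_fn(current)
}

current <- c(0.3, -1.2)
proposed <- c(0.5, -0.8)
a1 <- log_accept_ratio(log_prob_normalized, current, proposed)
a2 <- log_accept_ratio(log_prob_unnormalized, current, proposed)
all.equal(a1, a2)  # TRUE: the -log(2 * pi) / 2 terms cancel
```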

step_size

Tensor or list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable.
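The effect of step size on leapfrog accuracy can be seen on a 1-D standard normal target. This base-R sketch (hypothetical helper `leapfrog_energy_error()`, identity mass matrix assumed) reports the absolute change in the Hamiltonian after 20 leapfrog steps; larger energy errors translate into lower acceptance probabilities.

```r
# Absolute energy error after n_steps leapfrog steps for log p(x) = -0.5 * x^2.
leapfrog_energy_error <- function(eps, n_steps = 20, x = 1, r = 0) {
  hamiltonian <- function(x, r) 0.5 * x^2 + 0.5 * r^2
  h0 <- hamiltonian(x, r)
  for (i in seq_len(n_steps)) {
    r <- r - 0.5 * eps * x  # half step in momentum (grad log p = -x)
    x <- x + eps * r        # full step in position
    r <- r - 0.5 * eps * x  # half step in momentum
  }
  abs(hamiltonian(x, r) - h0)
}

errs <- sapply(c(0.1, 0.5, 1.5, 2.5), leapfrog_energy_error)
errs
```

For this target the error stays bounded (at most eps^2 / 8 from this starting point) while eps < 2, and grows exponentially for eps = 2.5, which is the "too-large step size" regime described above.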

max_tree_depth

Maximum depth of the tree implicitly built by NUTS. The maximum number of leapfrog steps is bounded by 2^max_tree_depth, i.e., the number of leaves in a binary tree max_tree_depth levels deep. The default setting of 10 takes up to 1024 leapfrog steps.

max_energy_diff

Scalar threshold on the energy difference at each leapfrog step; divergent samples are defined as leapfrog steps whose energy difference exceeds this threshold. Defaults to 1000.
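A sketch of the divergence check this threshold controls, in base R. It assumes the energy is the Hamiltonian (negative log-density plus Gaussian kinetic energy with an identity mass matrix) and that a step is divergent when the energy grows by more than max_energy_diff; `is_divergent()` is a hypothetical helper and the exact comparison inside tfprobability may differ in detail.

```r
# Hypothetical helper: flag a leapfrog step whose energy grew past the threshold.
is_divergent <- function(log_prob_start, momentum_start,
                         log_prob_new, momentum_new,
                         max_energy_diff = 1000) {
  # Hamiltonian: potential (-log p) plus kinetic energy 0.5 * |p|^2.
  energy <- function(log_prob, momentum) -log_prob + 0.5 * sum(momentum^2)
  energy_diff <- energy(log_prob_new, momentum_new) -
    energy(log_prob_start, momentum_start)
  energy_diff > max_energy_diff
}

is_divergent(-1.0, c(0.5, 0.5), -1.2, c(0.6, 0.4))  # FALSE: energy changed by 0.21
is_divergent(-1.0, c(0.5, 0.5), -5e3, c(0.6, 0.4))  # TRUE: energy grew by about 5000
```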

unrolled_leapfrog_steps

The number of leapfrog steps to unroll per tree expansion step. Applies a direct linear multiplier to the maximum trajectory length implied by max_tree_depth. Defaults to 1.
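Taken together, max_tree_depth and unrolled_leapfrog_steps bound the work per NUTS iteration. This arithmetic sketch uses a hypothetical helper and simply multiplies the two bounds as the argument descriptions state; real trajectories usually stop earlier at a U-turn or a divergence.

```r
# Upper bound on leapfrog steps per iteration implied by the two tree-size arguments.
max_leapfrog_steps <- function(max_tree_depth = 10, unrolled_leapfrog_steps = 1) {
  unrolled_leapfrog_steps * 2^max_tree_depth
}

max_leapfrog_steps()                                 # 1024 with the defaults
max_leapfrog_steps(max_tree_depth = 6)               # 64
max_leapfrog_steps(10, unrolled_leapfrog_steps = 4)  # 4096
```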

seed

Integer to seed the random number generator.

name

Name prefixed to Ops created by this function. Default value: NULL (i.e., 'nuts_kernel').

Value

A Monte Carlo sampling kernel.

Details

The one_step function can update multiple chains in parallel. It assumes that a prefix of leftmost dimensions of current_state index independent chain states (and are therefore updated independently). The output of target_log_prob_fn(current_state) should sum log-probabilities across all event dimensions. Slices along the rightmost dimensions may have different target distributions; for example, current_state[0][0, ...] could have a different target distribution from current_state[0][1, ...]. These semantics are governed by target_log_prob_fn, which is called with the parts of current_state as its arguments. (The number of independent chains is tf$size(target_log_prob_fn(current_state)).)
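The chain layout can be mimicked with a plain R matrix: chains along the leftmost dimension (rows), event dimensions on the right (columns). This base-R sketch assumes independent standard normal targets and does not reproduce tfprobability's tensor or multi-part state handling; it only shows that summing over event dimensions leaves one log-probability per chain. The 50 chains mirror nchain <- 50 in the Examples section.

```r
# 50 chains (rows) of a 3-dimensional state (columns).
n_chains <- 50
set.seed(1)
current_state <- matrix(rnorm(n_chains * 3), nrow = n_chains, ncol = 3)

# Sum log-densities over the event (rightmost) dimension: one value per chain.
target_log_prob <- function(state) rowSums(dnorm(state, log = TRUE))

lp <- target_log_prob(current_state)
length(lp)  # 50, the number of independent chains
```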

References

See also

Examples

# \donttest{
library(tensorflow)
library(tfprobability)

predictors <- tf$cast(
  c(201, 244, 47, 287, 203, 58, 210, 202, 198, 158, 165, 201, 157,
    131, 166, 160, 186, 125, 218, 146),
  tf$float32
)
obs <- tf$cast(
  c(592, 401, 583, 402, 495, 173, 479, 504, 510, 416, 393, 442, 317, 311, 400,
    337, 423, 334, 533, 344),
  tf$float32
)
y_sigma <- tf$cast(
  c(61, 25, 38, 15, 21, 15, 27, 14, 30, 16, 14, 25, 52, 16, 34, 31, 42, 26,
    16, 22),
  tf$float32
)

# Robust linear regression model
robust_lm <- tfd_joint_distribution_sequential(
  list(
    tfd_normal(loc = 0, scale = 1, name = "b0"),
    tfd_normal(loc = 0, scale = 1, name = "b1"),
    tfd_half_normal(5, name = "df"),
    function(df, b1, b0)
      tfd_independent(
        # Likelihood
        tfd_student_t(
          df = tf$expand_dims(df, axis = -1L),
          loc = tf$expand_dims(b0, axis = -1L) +
            tf$expand_dims(b1, axis = -1L) * predictors[tf$newaxis, ],
          scale = y_sigma,
          name = "st"
        ),
        name = "ind"
      )
  ),
  validate_args = TRUE
)

log_prob <- function(b0, b1, df) {
  robust_lm %>% tfd_log_prob(list(b0, b1, df, obs))
}

step_size0 <- Map(function(x) tf$cast(x, tf$float32), c(1, .2, .5))

number_of_steps <- 10
burnin <- 5
nchain <- 50

run_chain <- function() {
  # Random initialization of the starting position of each chain
  samples <- robust_lm %>% tfd_sample(nchain)
  b0 <- samples[[1]]
  b1 <- samples[[2]]
  df <- samples[[3]]

  # Bijectors mapping the unconstrained real line to each parameter's support
  # (tfb_exp() keeps df positive)
  unconstraining_bijectors <- list(
    tfb_identity(),
    tfb_identity(),
    tfb_exp()
  )

  # Record the step size and log acceptance ratio of the inner NUTS kernel
  trace_fn <- function(x, pkr) {
    list(
      pkr$inner_results$inner_results$step_size,
      pkr$inner_results$inner_results$log_accept_ratio
    )
  }

  nuts <- mcmc_no_u_turn_sampler(
    target_log_prob_fn = log_prob,
    step_size = step_size0
  ) %>%
    mcmc_transformed_transition_kernel(bijector = unconstraining_bijectors) %>%
    mcmc_dual_averaging_step_size_adaptation(
      num_adaptation_steps = burnin,
      step_size_setter_fn = function(pkr, new_step_size)
        pkr$`_replace`(
          inner_results = pkr$inner_results$`_replace`(step_size = new_step_size)
        ),
      step_size_getter_fn = function(pkr) pkr$inner_results$step_size,
      log_accept_prob_getter_fn = function(pkr) pkr$inner_results$log_accept_ratio
    )

  nuts %>% mcmc_sample_chain(
    num_results = number_of_steps,
    num_burnin_steps = burnin,
    current_state = list(b0, b1, df),
    trace_fn = trace_fn
  )
}

run_chain <- tensorflow::tf_function(run_chain)
res <- run_chain()
# }