This function generalizes VIMCO (Mnih and Rezende, 2016) to Csiszar f-Divergences.

vi_csiszar_vimco(
  f,
  p_log_prob,
  q,
  num_draws,
  num_batch_draws = 1,
  seed = NULL,
  name = NULL
)

Arguments

f

function representing a Csiszar-function in log-space.

p_log_prob

function representing the natural-log of the probability under distribution p. (In variational inference p is the joint distribution.)

q

tfd$Distribution-like instance; must implement sample(n, seed) and log_prob(x). (In variational inference q is the approximate posterior distribution.)

num_draws

Integer scalar number of draws used to approximate the f-Divergence expectation.

num_batch_draws

Integer scalar number of batch draws, i.e., independent replications of the num_draws-sample VIMCO estimate, averaged to approximate the outer expectation.

seed

Integer seed for q$sample.

name

String prefixed to Ops created by this function.
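
For orientation, a minimal usage sketch follows. It is illustrative only (a toy log-joint, not a real model), assuming the tensorflow and tfprobability packages are attached and that vi_kl_reverse and tfd_normal are available:

library(tensorflow)
library(tfprobability)

# Toy unnormalized log-joint log p(x, h): a unit normal centered at 1,
# standing in for prior times likelihood. Purely illustrative.
p_log_prob <- function(h) {
  tfd_normal(loc = 1, scale = 1)$log_prob(h)
}

# Approximate posterior q(h | x).
q <- tfd_normal(loc = 0, scale = 1)

# VIMCO objective with the reverse-KL Csiszar-function.
loss <- vi_csiszar_vimco(
  f = vi_kl_reverse,
  p_log_prob = p_log_prob,
  q = q,
  num_draws = 10
)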

Value

vimco, the Csiszar f-Divergence generalized VIMCO objective.

Details

Note: if q$reparameterization_type == tfd$FULLY_REPARAMETERIZED, consider using monte_carlo_csiszar_f_divergence.

The VIMCO loss is:

vimco = f(Avg{logu[i] : i=0,...,m-1})
where,
logu[i] = log( p(x, h[i]) / q(h[i] | x) )
h[i] iid~ q(H | x)
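
Read Avg here as the log of the arithmetic mean of the weights u[i] = exp(logu[i]) (a log-mean-exp), consistent with the swap-out formula below. A toy base-R sketch of the objective value, with made-up log-weights and the reverse-KL Csiszar-function f(logu) = -logu standing in for f:

# Stable log-mean-exp: computes log(Avg{u[i]}) from logu without overflow.
log_mean_exp <- function(logu) {
  m <- max(logu)
  m + log(mean(exp(logu - m)))
}

# Hypothetical log-weights logu[i] = log p(x, h[i]) - log q(h[i] | x).
logu <- c(-1.2, -0.7, -2.3, -0.9)

# Reverse-KL Csiszar-function in log-space.
f <- function(logu) -logu

vimco_value <- f(log_mean_exp(logu))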

Interestingly, the VIMCO gradient is not the naive gradient of vimco. Rather, it is characterized by:

grad[vimco] - variance_reducing_term

where,

variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
                              (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
                              : i=0,...,m-1 }
h[j;i] = u[j]                               for j != i
       = GeometricAverage{ u[k] : k != i }  for j == i

(We omitted stop_gradient for brevity. See implementation for more details.) The Avg{h[j;i] : j} term is a kind of "swap-out average" where the i-th element has been replaced by the leave-i-out Geometric-average.
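
To make the swap-out average concrete, here is a small base-R sketch (reusing log_mean_exp and the hypothetical logu from the sketch above); in log-space, the leave-i-out geometric average of the u's is simply the leave-i-out arithmetic mean of logu:

# Swap-out: h[j;i] equals u[j] except at j == i, where it is the
# leave-i-out geometric average of the u's.
swap_out_logu <- function(logu, i) {
  out <- logu
  out[i] <- mean(logu[-i])  # log GeometricAverage{ u[k] : k != i }
  out
}

# One control-variate term log(Avg{h[j;i] : j}) per draw i.
sapply(seq_along(logu), function(i) log_mean_exp(swap_out_logu(logu, i)))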

This implementation prefers numerical precision over efficiency; its cost is O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape)). (The hidden constant may be fairly large, perhaps around 12.)

References

Mnih A., Rezende D. J. (2016). Variational Inference for Monte Carlo Objectives. Proceedings of the 33rd International Conference on Machine Learning (ICML). arXiv:1602.06725.

See also

monte_carlo_csiszar_f_divergence, for the fully reparameterized case noted under Details.