Source: R/vi-functions.R, vi_csiszar_vimco.Rd

This function generalizes VIMCO (Mnih and Rezende, 2016) to Csiszar f-Divergences.
vi_csiszar_vimco( f, p_log_prob, q, num_draws, num_batch_draws = 1, seed = NULL, name = NULL )
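| Argument | Description |
|---|---|
| f | Function representing a Csiszar function in log-space. |
| p_log_prob | Function representing the natural log of the probability under distribution p. (In variational inference, p is the joint distribution.) |
| q | Distribution from which the draws h[i] are sampled (in variational inference, the approximate posterior). |
| num_draws | Integer scalar number of draws (m) used to approximate the f-Divergence expectation. |
| num_batch_draws | Integer scalar number of independent batches of num_draws draws; the resulting VIMCO estimates are averaged. |
| seed | Optional integer seed used when sampling from q. |
| name | String prefixed to Ops created by this function. |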
Returns vimco, the VIMCO objective generalized to Csiszar f-Divergences.
Note: if `q.reparameterization_type` is `tfd.FULLY_REPARAMETERIZED`,
consider using `monte_carlo_csiszar_f_divergence` instead.
The VIMCO loss is:
vimco = f(log Avg{u[i] : i=0,...,m-1})
where
u[i] = p(x, h[i]) / q(h[i] | x), i.e., logu[i] = log(u[i])
h[i] iid~ q(H | x)
and m = num_draws.
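The objective can be evaluated directly from the log-ratios logu: it is f applied to the log of the average of u[i] = exp(logu[i]). The base-R sketch below does not call tfprobability; the helper names `log_mean_exp` and `vimco_objective`, and the example Csiszar function `f_kl_reverse(logu) = -logu` (a reverse-KL-style choice), are illustrative assumptions. The log-average is computed stably by subtracting the maximum before exponentiating.

```r
# Numerically stable log(mean(exp(x))): shift by the max before exponentiating.
log_mean_exp <- function(x) {
  mx <- max(x)
  if (!is.finite(mx)) return(mx)  # all -Inf (or an Inf) needs no shifting
  mx + log(mean(exp(x - mx)))
}

# Hypothetical helper: VIMCO objective f(log Avg{u[i]}) for one set of m draws.
vimco_objective <- function(f, logu) {
  f(log_mean_exp(logu))
}

# Assumed example Csiszar function in log-space (reverse-KL style).
f_kl_reverse <- function(logu) -logu

# logu[i] = log p(x, h[i]) - log q(h[i] | x) for m = 4 draws (made-up values).
logu <- c(-1.2, 0.3, -0.5, 0.1)
vimco_objective(f_kl_reverse, logu)
```

Working in log-space matters because u[i] can overflow or underflow when p and q disagree strongly, while the shifted exponentials stay in a safe range.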
Interestingly, the VIMCO gradient is not the naive gradient of vimco.
Rather, it is characterized by:
grad[vimco] + variance_reducing_term
where,
variance_reducing_term = Sum{ grad[log q(h[i] | x)] * (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) : i=0, ..., m-1 }
h[j;i] = u[j] for j!=i, GeometricAverage{ u[k] : k!=i} for j==i
(We omitted stop_gradient for brevity. See implementation for more details.)
The Avg{h[j;i] : j} term is a kind of "swap-out average" in which the i-th
element has been replaced by the geometric average of the other m-1 elements.
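The swap-out average and the per-draw coefficients (vimco - f(log Avg{h[j;i] : j})) that multiply grad[log q(h[i] | x)] can be sketched in base R as follows. This is not the package's implementation (which operates on tensors and handles stop_gradient); the helper names `log_swap_out_avg` and `vimco_score_weights` are hypothetical, and `f_kl_reverse` is an assumed example Csiszar function. A direct loop over i is used for clarity, so the sketch costs O(m^2) per set of draws.

```r
# Numerically stable log(mean(exp(x))).
log_mean_exp <- function(x) {
  mx <- max(x)
  if (!is.finite(mx)) return(mx)
  mx + log(mean(exp(x - mx)))
}

# Hypothetical helper: for each i, log Avg{h[j;i] : j}, where u[i] is swapped
# out for the geometric average of the other u[k]. In log-space that geometric
# average is simply mean(logu[-i]).
log_swap_out_avg <- function(logu) {
  m <- length(logu)
  stopifnot(m >= 2)  # the leave-i-out average needs at least one other draw
  vapply(seq_len(m), function(i) {
    logh <- logu
    logh[i] <- mean(logu[-i])  # log GeometricAverage{u[k] : k != i}
    log_mean_exp(logh)
  }, numeric(1))
}

# Coefficients multiplying grad[log q(h[i] | x)] in the score-function term.
vimco_score_weights <- function(f, logu) {
  vimco <- f(log_mean_exp(logu))
  vimco - f(log_swap_out_avg(logu))
}

# Assumed example Csiszar function in log-space (reverse-KL style).
f_kl_reverse <- function(logu) -logu

logu <- c(-1.2, 0.3, -0.5, 0.1)  # made-up log-ratios for m = 4 draws
w <- vimco_score_weights(f_kl_reverse, logu)
```

As an illustration, if q(h | x) were a unit-variance normal with mean mu, then grad_mu log q(h[i] | x) = h[i] - mu, so the score-function term for mu would be sum(w * (h - mu)), where h holds the draws that produced logu. When all logu[i] are equal, every weight is zero, so that term vanishes.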
This implementation favors numerical precision over efficiency: its cost is
O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape)),
with a constant factor that may be fairly large (perhaps around 12).
Other vi-functions:
vi_amari_alpha(),
vi_arithmetic_geometric(),
vi_chi_square(),
vi_dual_csiszar_function(),
vi_fit_surrogate_posterior(),
vi_jeffreys(),
vi_jensen_shannon(),
vi_kl_forward(),
vi_kl_reverse(),
vi_log1p_abs(),
vi_modified_gan(),
vi_monte_carlo_variational_loss(),
vi_pearson(),
vi_squared_hellinger(),
vi_symmetrized_csiszar_function()