vi_csiszar_vimco
This function generalizes VIMCO (Mnih and Rezende, 2016) to Csiszar f-Divergences.
vi_csiszar_vimco(f, p_log_prob, q, num_draws, num_batch_draws = 1, seed = NULL, name = NULL)
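Argument | Description |
---|---|
f | Function representing a Csiszar-function in log-space. |
p_log_prob | Function representing the natural-log of the probability under distribution `p`. (In variational inference, `p` is the joint distribution.) |
q | Distribution instance from which the draws `h[i]` are sampled; it must implement `sample()` and `log_prob()`. (In variational inference, `q` is the approximate posterior distribution.) |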
num_draws | Integer scalar number of draws used to approximate the f-Divergence expectation. |
num_batch_draws | Integer scalar number of independent batches of `num_draws` draws; the VIMCO objective is computed per batch and the results are averaged. |
seed | Integer seed used when sampling from `q`. |
name | String prefixed to Ops created by this function. |
Returns `vimco`, the Csiszar f-Divergence generalized VIMCO objective.
Note: if `q.reparameterization_type = tfd.FULLY_REPARAMETERIZED`, consider using `vi_monte_carlo_variational_loss()` instead.
The VIMCO loss is:

$$
\mathrm{vimco} = f\left(\log \frac{1}{m}\sum_{i=0}^{m-1} u[i]\right),
\qquad \text{where} \quad
\log u[i] = \log \frac{p(x, h[i])}{q(h[i] \mid x)},
\quad h[i] \overset{\text{iid}}{\sim} q(H \mid x).
$$
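As an illustration of the loss above, the base-R sketch below evaluates `f(log Avg{u[i]})` for one batch of `m` draws, working in log-space with a max-shift so that large `logu` values do not overflow. It is not the package implementation. The helper names are hypothetical, and the choice `f(logu) = -logu` (the reverse-KL Csiszar function without self-normalization) is an assumption made only for the example; with that `f`, `vimco` is the negative of the importance-weighted log-likelihood bound.

```r
# Numerically stable log(mean(exp(x))): shift by the maximum before
# exponentiating so very large entries do not overflow.
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# VIMCO objective for one batch of m draws:
#   vimco = f(log Avg{u[i] : i = 0, ..., m-1}),  with u[i] = exp(logu[i]).
# `logu` holds logu[i] = log p(x, h[i]) - log q(h[i] | x).
vimco_objective <- function(f, logu) {
  f(log_mean_exp(logu))
}

# Illustrative f (an assumption for this example): reverse KL in
# log-space without self-normalization, f(logu) = -logu.
f_kl_reverse <- function(logu) -logu

logu <- log(c(0.5, 1, 2, 4))
vimco_objective(f_kl_reverse, logu)  # -log(1.875), about -0.6286

# The shifted form stays finite where the naive form overflows:
log(mean(exp(c(1000, 1000))))  # Inf
log_mean_exp(c(1000, 1000))    # 1000
```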
Interestingly, the VIMCO gradient is not the naive gradient of `vimco`. Rather, it is characterized by:
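$$
\nabla[\mathrm{vimco}] - \text{variance\_reducing\_term},
$$

where

$$
\begin{aligned}
\text{variance\_reducing\_term} &= \sum_{i=0}^{m-1} \nabla\big[\log q(h[i] \mid x)\big] \left(\mathrm{vimco} - f\left(\log \frac{1}{m}\sum_{j=0}^{m-1} h[j;i]\right)\right), \\
h[j;i] &= \begin{cases} u[j], & j \neq i, \\ \left(\prod_{k \neq i} u[k]\right)^{1/(m-1)}, & j = i. \end{cases}
\end{aligned}
$$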
(We omitted `stop_gradient` for brevity. See the implementation for more details.)
The `Avg{h[j;i] : j}` term is a kind of "swap-out average" in which the `i`-th element has been replaced by the leave-`i`-out geometric average.
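The swap-out average and the per-draw factors `(vimco - f(log Avg{h[j;i] : j}))` that multiply `grad[log q(h[i] | x)]` can be sketched in base R as follows. This is a hypothetical illustration, not the package code: it computes values only (no gradients, so `stop_gradient` plays no role), loops over `i` for clarity at O(m^2) cost per batch, and again assumes the illustrative `f(logu) = -logu`. Working in log-space, the leave-`i`-out geometric average of `u` is simply the mean of `logu` with element `i` dropped.

```r
# Numerically stable log(sum(exp(x))).
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# For each i, log Avg{h[j;i] : j = 0, ..., m-1}, where h[j;i] = u[j] for
# j != i and h[i;i] is the geometric average of the other u[k].
log_swap_out_avg <- function(logu) {
  m <- length(logu)
  stopifnot(m > 1)  # a leave-one-out average needs at least two draws
  vapply(seq_len(m), function(i) {
    h <- logu               # log-space copy: h holds log h[j;i]
    h[i] <- mean(logu[-i])  # log of the leave-i-out geometric average
    log_sum_exp(h) - log(m)
  }, numeric(1))
}

# Per-draw factors (vimco - f(log Avg{h[j;i] : j})) that multiply
# grad[log q(h[i] | x)] in variance_reducing_term.
vimco_signals <- function(f, logu) {
  vimco <- f(log_sum_exp(logu) - log(length(logu)))
  vimco - f(log_swap_out_avg(logu))
}

f_kl_reverse <- function(logu) -logu  # illustrative f (assumption)

logu <- log(c(1, 4))
exp(log_swap_out_avg(logu))        # 4 1
vimco_signals(f_kl_reverse, logu)  # log(1.6) and -log(2.5)
```

With two draws, swapping out `u[1] = 1` for the geometric average of the remaining draw gives `Avg{4, 4} = 4`, and swapping out `u[2] = 4` gives `Avg{1, 1} = 1`. When all `logu[i]` are equal, every factor is zero, so the correction term vanishes.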
This implementation prefers numerical precision over efficiency: its cost is `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`, and the constant factor may be fairly large (perhaps around 12).
Other vi-functions: vi_amari_alpha(), vi_arithmetic_geometric(), vi_chi_square(), vi_dual_csiszar_function(), vi_fit_surrogate_posterior(), vi_jeffreys(), vi_jensen_shannon(), vi_kl_forward(), vi_kl_reverse(), vi_log1p_abs(), vi_modified_gan(), vi_monte_carlo_variational_loss(), vi_pearson(), vi_squared_hellinger(), vi_symmetrized_csiszar_function()