Real NVP models a normalizing flow on a D-dimensional distribution via a
single (D-d)-dimensional conditional distribution (Dinh et al., 2017):
y[d:D] = x[d:D] * tf.exp(log_scale_fn(x[0:d])) + shift_fn(x[0:d])
y[0:d] = x[0:d]
The last D-d units are scaled and shifted based on the first d units only,
while the first d units are 'masked' and left unchanged. Real NVP's
shift_and_log_scale_fn computes vector-valued quantities.
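As a minimal illustration of the two equations above, the following plain-Python sketch (standard library only, no TensorFlow) applies the coupling to a list of floats and then inverts it. The names coupling_forward, coupling_inverse and toy_shift_and_log_scale are hypothetical and are not part of tfprobability; the toy conditioner merely stands in for a learned network.

```python
import math

def toy_shift_and_log_scale(x_masked, output_units):
    """Hypothetical conditioner: depends on the masked units only."""
    s = sum(x_masked)
    shift = [0.5 * s] * output_units
    log_scale = [0.1 * s] * output_units
    return shift, log_scale

def coupling_forward(x, d, fn=toy_shift_and_log_scale):
    # y[0:d] = x[0:d]; y[d:D] = x[d:D] * exp(log_scale) + shift
    x0, x1 = x[:d], x[d:]
    shift, log_scale = fn(x0, len(x1))
    y1 = [xi * math.exp(ls) + sh for xi, ls, sh in zip(x1, log_scale, shift)]
    return x0 + y1

def coupling_inverse(y, d, fn=toy_shift_and_log_scale):
    # The masked units pass through unchanged, so the same shift and
    # log_scale can be recomputed from y[0:d] without any iteration.
    y0, y1 = y[:d], y[d:]
    shift, log_scale = fn(y0, len(y1))
    x1 = [(yi - sh) * math.exp(-ls) for yi, ls, sh in zip(y1, log_scale, shift)]
    return y0 + x1

x = [1.0, -2.0, 0.5, 3.0]
y = coupling_forward(x, d=2)
x_rec = coupling_inverse(y, d=2)
```

Because the first d units are copied through, the inverse only needs one evaluation of the conditioner, which is why neither direction requires a sequential loop over units.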
For scale-and-shift transforms that do not depend on any masked units, i.e.
d=0, use the tfb_affine bijector with learned parameters instead.
Masking is currently only supported for base distributions with
event_ndims=1. For more sophisticated masking schemes like checkerboard or
channel-wise masking (Papamakarios et al., 2016), use the tfb_permute
bijector to re-order desired masked units into the first d units. For base
distributions with event_ndims > 1, use the tfb_reshape bijector to
flatten the event shape.
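The re-ordering idea can be sketched in plain Python, assuming an alternating mask over six units. This is only an index-shuffling illustration, not the tfb_permute implementation; permute and invert_permutation are hypothetical helper names.

```python
def permute(x, perm):
    """Return x re-ordered so that position i holds x[perm[i]]."""
    return [x[i] for i in perm]

def invert_permutation(perm):
    """Build the permutation that undoes `perm`."""
    inv = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    return inv

D = 6
conditioning = [0, 2, 4]  # units we want masked (left unchanged)
rest = [i for i in range(D) if i not in conditioning]
perm = conditioning + rest  # masked units come first

x = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
x_perm = permute(x, perm)  # first d = 3 entries are now the masked units
x_back = permute(x_perm, invert_permutation(perm))  # original order restored
```

In a flow, a coupling step with num_masked = 3 would sit between the permutation and its inverse, so that the alternating units act as the conditioning set.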
tfb_real_nvp(
  num_masked,
  shift_and_log_scale_fn,
  is_constant_jacobian = FALSE,
  validate_args = FALSE,
  name = NULL
)
Argument | Description |
---|---|
num_masked | integer indicating that the first d units of the event should be masked. Must be in the closed interval |
shift_and_log_scale_fn | Function which computes shift and log_scale from both the forward domain (x) and the inverse domain (y). Calculation must respect the "autoregressive property". Suggested default: |
is_constant_jacobian | Logical, default: FALSE. When TRUE the implementation assumes log_scale does not depend on the forward domain (x) or inverse domain (y) values. (No validation is made; is_constant_jacobian=FALSE is always safe but possibly computationally inefficient.) |
validate_args | Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE, and the inputs are invalid, correct behavior is not guaranteed. |
name | name prefixed to Ops created by this class. |
A bijector instance.
Recall that the MAF bijector (Papamakarios et al., 2016) implements a normalizing flow via an autoregressive transformation. MAF and IAF have opposite computational tradeoffs: MAF can train all units in parallel but must sample units sequentially, while IAF must train units sequentially but can sample in parallel. In contrast, Real NVP can compute both the forward and inverse transformations in parallel. However, the lack of an autoregressive transformation makes it less expressive on a per-bijector basis.
A "valid" shift_and_log_scale_fn must compute each shift (aka loc or "mu" in Papamakarios et al. (2016)) and log(scale) (aka "alpha" in Papamakarios et al. (2016)) such that each is broadcastable with the arguments to forward and inverse, i.e., such that the calculations in forward and inverse are possible. For convenience, tfb_real_nvp_default_template is offered as a possible shift_and_log_scale_fn function.
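One consequence of this contract is that the log-determinant of the Jacobian is cheap to evaluate. The Jacobian of the coupling is triangular, with ones on the diagonal for the masked units and exp(log_scale) for the remaining units, so the forward log-determinant is the sum of log_scale. The sketch below checks this numerically in plain Python; the function names are hypothetical and only mirror the method names mentioned in this page.

```python
import math

def forward_log_det_jacobian(log_scale):
    # log|det J| = sum of log_scale over the D - d transformed units;
    # the d masked units each contribute log(1) = 0.
    return sum(log_scale)

def inverse_log_det_jacobian(log_scale):
    # The inverse map rescales by exp(-log_scale), so the sign flips.
    return -sum(log_scale)

log_scale = [0.2, -0.7, 1.1]  # one entry per transformed unit
fldj = forward_log_det_jacobian(log_scale)
ildj = inverse_log_det_jacobian(log_scale)

# Cross-check against the product of the diagonal Jacobian entries.
det = math.prod(math.exp(ls) for ls in log_scale)
```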
NICE (Dinh et al., 2014) is a special case of the Real NVP bijector which discards the scale transformation, resulting in a constant-time inverse-log-determinant-Jacobian. To use a NICE bijector instead of Real NVP, shift_and_log_scale_fn should return (shift, NULL), and is_constant_jacobian should be set to TRUE when constructing the bijector with tfb_real_nvp(). Calling tfb_real_nvp_default_template with shift_only=TRUE returns one such NICE-compatible shift_and_log_scale_fn.
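The shift-only case can be sketched in plain Python as follows. Python's None plays the role of NULL here, and shift_only_fn, nice_forward and nice_inverse are hypothetical names used for illustration rather than tfprobability functions.

```python
def shift_only_fn(x_masked, output_units):
    """Hypothetical NICE-style conditioner: returns (shift, None)."""
    shift = [2.0 * sum(x_masked)] * output_units
    return shift, None

def nice_forward(x, d, fn=shift_only_fn):
    x0, x1 = x[:d], x[d:]
    shift, log_scale = fn(x0, len(x1))
    assert log_scale is None  # no scaling is applied in this sketch
    return x0 + [xi + sh for xi, sh in zip(x1, shift)]

def nice_inverse(y, d, fn=shift_only_fn):
    y0, y1 = y[:d], y[d:]
    shift, _ = fn(y0, len(y1))
    return y0 + [yi - sh for yi, sh in zip(y1, shift)]

def nice_forward_log_det_jacobian():
    # A pure shift has a unit-diagonal Jacobian, so log|det J| is 0
    # regardless of the input; this is the "constant Jacobian" case.
    return 0.0

x = [0.5, 1.5, -1.0]
y = nice_forward(x, d=1)
x_rec = nice_inverse(y, d=1)
```

Since the log-determinant does not depend on x or y, nothing needs to be recomputed per input, which is the saving that is_constant_jacobian = TRUE is meant to capture.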
Caching: the scalar input depth D of the base distribution is not known at
construction time. The first call to any of forward(x), inverse(x),
inverse_log_det_jacobian(x), or forward_log_det_jacobian(x) memoizes
D, which is re-used in subsequent calls. This shape must be known prior to
graph execution (which is the case if using tf$layers).
For usage examples see tfb_forward(), tfb_inverse(), tfb_inverse_log_det_jacobian().
Other bijectors: tfb_absolute_value(), tfb_affine_linear_operator(), tfb_affine_scalar(), tfb_affine(), tfb_ascending(), tfb_batch_normalization(), tfb_blockwise(), tfb_chain(), tfb_cholesky_outer_product(), tfb_cholesky_to_inv_cholesky(), tfb_correlation_cholesky(), tfb_cumsum(), tfb_discrete_cosine_transform(), tfb_expm1(), tfb_exp(), tfb_ffjord(), tfb_fill_scale_tri_l(), tfb_fill_triangular(), tfb_glow(), tfb_gompertz_cdf(), tfb_gumbel_cdf(), tfb_gumbel(), tfb_identity(), tfb_inline(), tfb_invert(), tfb_iterated_sigmoid_centered(), tfb_kumaraswamy_cdf(), tfb_kumaraswamy(), tfb_lambert_w_tail(), tfb_masked_autoregressive_default_template(), tfb_masked_autoregressive_flow(), tfb_masked_dense(), tfb_matrix_inverse_tri_l(), tfb_matvec_lu(), tfb_normal_cdf(), tfb_ordered(), tfb_pad(), tfb_permute(), tfb_power_transform(), tfb_rational_quadratic_spline(), tfb_rayleigh_cdf(), tfb_real_nvp_default_template(), tfb_reciprocal(), tfb_reshape(), tfb_scale_matvec_diag(), tfb_scale_matvec_linear_operator(), tfb_scale_matvec_lu(), tfb_scale_matvec_tri_l(), tfb_scale_tri_l(), tfb_scale(), tfb_shifted_gompertz_cdf(), tfb_shift(), tfb_sigmoid(), tfb_sinh_arcsinh(), tfb_sinh(), tfb_softmax_centered(), tfb_softplus(), tfb_softsign(), tfb_split(), tfb_square(), tfb_tanh(), tfb_transform_diagonal(), tfb_transpose(), tfb_weibull_cdf(), tfb_weibull()