When using Monte Carlo approximation (i.e., `use_exact_kl = FALSE`), it is presumed that the input distribution's concretization (i.e., `tf$convert_to_tensor(distribution)`) corresponds to a random sample. To override this behavior, set `test_points_fn`.
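
To make the difference between the exact and Monte Carlo modes concrete, here is a minimal base-R sketch. It does not call TensorFlow Probability and is not this layer's implementation. It estimates `KL[a, b]` for two univariate normals by averaging `log a(x) - log b(x)` over random test points `x` drawn from `a`, and compares the estimate with the closed-form value that an exact computation would return. The helpers `kl_exact_normal()` and `kl_mc_normal()` are hypothetical names used only for illustration.

```r
# Hypothetical helpers for illustration only; they are not part of tfprobability.

# Closed-form KL[a, b] for univariate normals a = N(mu_a, sd_a^2), b = N(mu_b, sd_b^2)
kl_exact_normal <- function(mu_a, sd_a, mu_b, sd_b) {
  log(sd_b / sd_a) + (sd_a^2 + (mu_a - mu_b)^2) / (2 * sd_b^2) - 0.5
}

# Monte Carlo estimate: draw test points from a, then average log a(x) - log b(x)
kl_mc_normal <- function(mu_a, sd_a, mu_b, sd_b, n = 1e5) {
  x <- rnorm(n, mean = mu_a, sd = sd_a)  # random test points sampled from a
  log_ratio <- dnorm(x, mu_a, sd_a, log = TRUE) -
    dnorm(x, mu_b, sd_b, log = TRUE)
  mean(log_ratio)
}

set.seed(1)
exact <- kl_exact_normal(1, 0.5, 0, 1)  # log(2) + 0.125
approx <- kl_mc_normal(1, 0.5, 0, 1)    # close to `exact` for large n
```

With the default `test_points_fn = tf$convert_to_tensor`, each batch member contributes a single random draw, which roughly corresponds to `n = 1` here: the estimate is still unbiased, but much noisier.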

```r
layer_kl_divergence_regularizer(
  object,
  distribution_b,
  use_exact_kl = FALSE,
  test_points_reduce_axis = NULL,
  test_points_fn = tf$convert_to_tensor,
  weight = NULL,
  ...
)
```

## Arguments

| Argument | Description |
|---|---|
| `object` | Model or layer object. |
| `distribution_b` | Distribution instance corresponding to `b` as in `KL[a, b]`. The previous layer's output is presumed to be a Distribution instance and plays the role of `a`. |
| `use_exact_kl` | Logical indicating whether the KL divergence should be calculated exactly via `tfp$distributions$kl_divergence` (`TRUE`) or via Monte Carlo approximation (`FALSE`). Default value: `FALSE`. |
| `test_points_reduce_axis` | Integer vector or scalar giving the dimensions over which to `reduce_mean` while calculating the Monte Carlo approximation of the KL divergence. As with all `tf$reduce_*` ops, `NULL` means reduce over all dimensions; `()` means reduce over none of them. Default value: `NULL` in the usage above (the underlying TensorFlow Probability layer defaults to `()`, i.e., no reduction). |
| `test_points_fn` | A function that takes a `tfp$distributions$Distribution` instance and returns a tensor of random test points used to approximate the KL divergence. Default value: `tf$convert_to_tensor`. |
| `weight` | Multiplier applied to the calculated KL divergence for each Keras batch member. Default value: `NULL` (i.e., do not weight each batch member). |
| `...` | Additional arguments passed to `args` of `keras::create_layer`. |
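
The reduction semantics of `test_points_reduce_axis`, and the per-member role of `weight`, can be pictured with a plain base-R analogy. This is a sketch under stated assumptions: the matrix layout (test-point samples in rows, batch members in columns) and the weight value are hypothetical, and the code does not call TensorFlow.

```r
# Base-R analogy only (not tfprobability code): a matrix of per-test-point
# log-density ratios log a(x) - log b(x), with rows = test-point samples
# and columns = batch members (a hypothetical layout).
set.seed(42)
n_samples <- 4
batch_size <- 3
log_ratio <- matrix(rnorm(n_samples * batch_size),
                    nrow = n_samples, ncol = batch_size)

# Reduce over all dimensions (what NULL means for tf$reduce_* ops)
kl_all <- mean(log_ratio)            # a single scalar

# Reduce over no dimensions (what an empty axis set means)
kl_none <- log_ratio                 # unchanged: one value per test point

# Reduce over the sample dimension only: one KL estimate per batch member
kl_per_member <- colMeans(log_ratio)

# A hypothetical weight, applied as a multiplier to each member's estimate
weight <- 0.5
weighted_kl <- weight * kl_per_member
```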

## Value

A Keras layer.

## See also

For an example of how to use this layer in a Keras model, see `layer_independent_normal()`.

Other distribution_layers: `layer_categorical_mixture_of_one_hot_categorical()`, `layer_distribution_lambda()`, `layer_independent_bernoulli()`, `layer_independent_logistic()`, `layer_independent_normal()`, `layer_independent_poisson()`, `layer_kl_divergence_add_loss()`, `layer_mixture_logistic()`, `layer_mixture_normal()`, `layer_mixture_same_family()`, `layer_multivariate_normal_tri_l()`, `layer_one_hot_categorical()`