This bijector implements a continuous dynamics transformation parameterized by a differential equation, where the initial and terminal conditions correspond to the domain (X) and the image (Y), respectively (see Details).

tfb_ffjord(
  state_time_derivative_fn,
  ode_solve_fn = NULL,
  trace_augmentation_fn = tfp$bijectors$ffjord$trace_jacobian_hutchinson,
  initial_time = 0,
  final_time = 1,
  validate_args = FALSE,
  dtype = tf$float32,
  name = "ffjord"
)

Arguments

state_time_derivative_fn

function taking arguments time (a scalar representing time) and state (a Tensor representing the state at the given time) and returning the time derivative of the state at that time.

ode_solve_fn

function taking arguments ode_fn (same as state_time_derivative_fn above), initial_time (a scalar representing the initial time of integration), initial_state (a Tensor of floating dtype representing the initial state) and solution_times (a 1D Tensor of floating dtype representing the times at which to obtain the solution), and returning a Tensor of shape [time_axis, initial_state$shape]. It is called with [final_time] as the solution_times argument and state_time_derivative_fn as the ode_fn argument. If NULL, a DormandPrince solver from tfp$math$ode is used. Default value: NULL

trace_augmentation_fn

function taking arguments ode_fn (a function, same as state_time_derivative_fn above), state_shape (the TensorShape of the state) and dtype (same as the dtype of the state), and returning a function that takes arguments time (a scalar representing the time at which the function is evaluated) and state (a Tensor representing the state at the given time) and computes a tuple (ode_fn(time, state), jacobian_trace_estimation). jacobian_trace_estimation should represent the trace of the Jacobian of ode_fn with respect to state. state_time_derivative_fn will be passed as the ode_fn argument. Default value: tfp$bijectors$ffjord$trace_jacobian_hutchinson

initial_time

Scalar float representing the time to which the x value of the bijector corresponds. Passed as initial_time to ode_solve_fn. For the default solver, it can be a float or a floating scalar Tensor. Default value: 0.

final_time

Scalar float representing the time to which the y value of the bijector corresponds. Passed as solution_times to ode_solve_fn. For the default solver, it can be a float or a floating scalar Tensor. Default value: 1.

validate_args

Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE, and the inputs are invalid, correct behavior is not guaranteed.

dtype

tf$DType to prefer when converting args to Tensors. Otherwise, a common dtype is inferred from the args, finally falling back to float32.

name

name prefixed to Ops created by this class.

Value

a bijector instance.

Details

The transformation is defined by the initial value problem

d/dt[state(t)] = state_time_derivative_fn(t, state(t))
state(initial_time) = X
state(final_time) = Y

For this transformation, log_det_jacobian satisfies a second differential equation, which reduces its computation to integrating the trace of the Jacobian of state_time_derivative_fn with respect to the state along the trajectory:

state_time_derivative = state_time_derivative_fn(t, state(t))
d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
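The two equations above can be checked numerically. The sketch below is not the tfprobability implementation: it uses base R only, a hypothetical toy field tanh(W %*% state + b * t) in place of state_time_derivative_fn, a hypothetical fixed-step RK4 integrator (rk4_integrate) in place of the DormandPrince solver, and the exact Jacobian trace in place of a stochastic estimate. It maps X to Y, inverts the map by integrating backward in time, and compares the integrated trace with log|det| of a finite-difference Jacobian of the flow map.

```r
# Hypothetical toy dynamics standing in for state_time_derivative_fn:
#   d/dt[state(t)] = tanh(W %*% state(t) + b * t)
W <- matrix(c(0.5, -1.0, 1.0, 0.3), nrow = 2)
b <- c(0.1, -0.2)
state_time_derivative_fn <- function(time, state) {
  as.vector(tanh(W %*% state + b * time))
}

# Augmented dynamics: the last coordinate carries log_det_jac(t).
# Here the Jacobian is diag(1 - tanh(z)^2) %*% W, so its trace is
# sum((1 - tanh(z)^2) * diag(W)) with z = W %*% state + b * t.
augmented_fn <- function(time, aug_state) {
  state <- aug_state[-length(aug_state)]
  z <- as.vector(W %*% state + b * time)
  c(tanh(z), sum((1 - tanh(z)^2) * diag(W)))
}

# Fixed-step classical Runge-Kutta (RK4); t1 < t0 integrates backward in time.
rk4_integrate <- function(ode_fn, state, t0, t1, n_steps = 200) {
  h <- (t1 - t0) / n_steps
  t <- t0
  for (i in seq_len(n_steps)) {
    k1 <- ode_fn(t, state)
    k2 <- ode_fn(t + h / 2, state + h / 2 * k1)
    k3 <- ode_fn(t + h / 2, state + h / 2 * k2)
    k4 <- ode_fn(t + h, state + h * k3)
    state <- state + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    t <- t + h
  }
  state
}

x <- c(0.8, -0.4)                                          # state(initial_time) = X
y <- rk4_integrate(state_time_derivative_fn, x, 0, 1)      # Y = state(final_time)
x_rec <- rk4_integrate(state_time_derivative_fn, y, 1, 0)  # inverse: integrate back to X

# log_det_jac(final_time), starting from log_det_jac(initial_time) = 0
aug_final <- rk4_integrate(augmented_fn, c(x, 0), 0, 1)
log_det_ode <- aug_final[length(aug_final)]

# Reference: log|det| of a central finite-difference Jacobian of the flow map X -> Y
flow <- function(x0) rk4_integrate(state_time_derivative_fn, x0, 0, 1)
eps <- 1e-5
J <- sapply(seq_along(x), function(j) {
  e <- replace(numeric(length(x)), j, eps)
  (flow(x + e) - flow(x - e)) / (2 * eps)
})
log_det_fd <- log(abs(det(J)))
```

Because the trace term never feeds back into the state, the first two entries of aug_final coincide with y; only the appended coordinate accumulates log_det_jac(t).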

The FFJORD constructor accepts two function arguments, ode_solve_fn and trace_augmentation_fn, which customize the integration of the differential equation and the trace estimation, respectively.

Differential equation integration is performed by a call to ode_solve_fn.

Custom ode_solve_fn must accept the following arguments:

  • ode_fn(time, state): Differential equation to be solved.

  • initial_time: Scalar float or floating Tensor representing the initial time.

  • initial_state: Floating Tensor representing the initial state.

  • solution_times: 1D floating Tensor of solution times.

It must return a Tensor of shape [solution_times$shape, initial_state$shape] representing the state values evaluated at solution_times. In addition, ode_solve_fn must support nested structures. For more details, see the interface of tfp$math$ode$Solver$solve().
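As an illustration of this argument contract, here is a minimal fixed-step RK4 solver in base R. It is a hypothetical sketch (rk4_ode_solve_fn and steps_per_unit_time are not part of tfprobability): it works on plain numeric vectors rather than Tensors, does not handle nested structures, and therefore cannot be passed to tfb_ffjord as-is. Row i of the returned matrix holds the state at solution_times[i], mirroring the [solution_times$shape, initial_state$shape] layout.

```r
# Hypothetical fixed-step RK4 solver following the ode_solve_fn argument contract.
# Plain R numeric vectors only: no Tensors, no nested structures.
rk4_ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times,
                             steps_per_unit_time = 100) {
  states <- matrix(NA_real_, nrow = length(solution_times),
                   ncol = length(initial_state))
  t <- initial_time
  state <- initial_state
  for (i in seq_along(solution_times)) {
    t_target <- solution_times[i]
    n_steps <- max(1L, ceiling(abs(t_target - t) * steps_per_unit_time))
    h <- (t_target - t) / n_steps   # negative h integrates backward in time
    for (s in seq_len(n_steps)) {
      k1 <- ode_fn(t, state)
      k2 <- ode_fn(t + h / 2, state + h / 2 * k1)
      k3 <- ode_fn(t + h / 2, state + h / 2 * k2)
      k4 <- ode_fn(t + h, state + h * k3)
      state <- state + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
      t <- t + h
    }
    t <- t_target            # avoid accumulating rounding error in t
    states[i, ] <- state     # row i holds the state at solution_times[i]
  }
  states
}

# Usage: d/dt[state] = -state has the exact solution state(t) = exp(-t) * state(0)
decay <- function(time, state) -state
res <- rk4_ode_solve_fn(decay, initial_time = 0, initial_state = c(1, 2),
                        solution_times = c(0.5, 1))
```

A fixed step count per unit of time keeps the sketch short; adaptive solvers such as Dormand-Prince instead adjust the step size to meet an error tolerance.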

The trace estimate is computed simultaneously with state_time_derivative by augmented_state_time_derivative_fn, which is generated by trace_augmentation_fn. trace_augmentation_fn takes state_time_derivative_fn, state$shape and state$dtype as arguments and returns an augmented_state_time_derivative_fn callable that computes both state_time_derivative and the unreduced trace_estimation.
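The default trace_augmentation_fn relies on a stochastic trace estimate. The sketch below shows the idea behind Hutchinson's estimator, tr(J) = E[v^T J v] for random probe vectors v with E[v v^T] = I, written as a trace_augmentation_fn-shaped function in base R. It is a hypothetical illustration, not tfp$bijectors$ffjord$trace_jacobian_hutchinson: it draws Gaussian probes with rnorm on every call, approximates the Jacobian-vector product with central finite differences instead of automatic differentiation, and ignores state_dtype. As described above, the second element of the returned list is left unreduced; summing it yields the single-probe estimate.

```r
# Hypothetical trace_augmentation_fn using Hutchinson's estimator:
#   tr(J) = E[v^T J v] for probe vectors v with E[v v^T] = I.
# Assumption: J %*% v is approximated by central finite differences.
hutchinson_trace_augmentation_fn <- function(ode_fn, state_shape, state_dtype,
                                             fd_eps = 1e-5) {
  # state_dtype is unused: plain R numerics are always double precision
  augmented_ode_fn <- function(time, state) {
    v <- rnorm(prod(state_shape))   # Gaussian probe vector
    jvp <- (ode_fn(time, state + fd_eps * v) -
              ode_fn(time, state - fd_eps * v)) / (2 * fd_eps)
    # Unreduced estimate: sum(v * jvp) is the single-probe estimate v^T J v
    list(ode_fn(time, state), v * jvp)
  }
  augmented_ode_fn
}

# Usage: for a linear field the Jacobian is A, so the estimates average to sum(diag(A))
set.seed(42)
A <- matrix(c(1, 2, 3, 4), nrow = 2)
linear_fn <- function(time, state) as.vector(A %*% state)
aug <- hutchinson_trace_augmentation_fn(linear_fn, state_shape = 2L,
                                        state_dtype = "double")
estimates <- replicate(20000, sum(aug(0, c(0.3, -0.7))[[2]]))
```

A single probe gives an unbiased but noisy estimate; averaging many calls, as in the usage lines, shows the mean approaching the exact trace.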

Custom ode_solve_fn and trace_augmentation_fn examples:

# custom_solver_fn: `function(f, t_initial, t_solutions, y_initial, ...)`
# custom_solver_args: list of additional arguments to pass to custom_solver_fn.
ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times) {
  # Reorder the arguments to match custom_solver_fn's signature
  do.call(
    custom_solver_fn,
    c(list(ode_fn, initial_time, solution_times, initial_state), custom_solver_args)
  )
}
ffjord <- tfb_ffjord(state_time_derivative_fn, ode_solve_fn = ode_solve_fn)

# state_time_derivative_fn: `function(time, state)`
# trace_jac_fn: `function(time, state)` unreduced jacobian trace function
trace_augmentation_fn <- function(ode_fn, state_shape, state_dtype) {
  augmented_ode_fn <- function(time, state) {
    # State derivative together with the unreduced trace estimate
    list(ode_fn(time, state), trace_jac_fn(time, state))
  }
  augmented_ode_fn
}
ffjord <- tfb_ffjord(state_time_derivative_fn, trace_augmentation_fn = trace_augmentation_fn)

For more details on FFJORD and continuous normalizing flows, see Chen et al. (2018) and Grathwohl et al. (2018).

References

See also