This callback replaces the model's weight values with the values of the optimizer's EMA weights (the exponential moving average of past model weight values, implementing "Polyak averaging") before model evaluation, and restores the previous weights after evaluation.
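As a rough illustration of what an exponential moving average of weights looks like, here is a minimal base-R sketch. The helper `ema_update()`, the momentum value, and the weight snapshots are all made up for this example and are not part of keras3; the sketch also assumes the average starts from the first snapshot.

```r
# Illustrative only: `ema_update()` is a hypothetical helper, not a keras3 function.
# Each update blends the running average with the newest weight values.
ema_update <- function(ema, weights, momentum = 0.99) {
  momentum * ema + (1 - momentum) * weights
}

# Made-up weight snapshots from three successive training steps
weights_history <- list(c(1, 2), c(3, 4), c(5, 6))

# Assumption: the running average starts from the first snapshot
ema <- weights_history[[1]]
for (w in weights_history[-1]) {
  ema <- ema_update(ema, w, momentum = 0.5)
}

ema
```

The averaged values lag behind the most recent weights, which is what makes them smoother than any single snapshot.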

The SwapEMAWeights callback is intended to be used together with an optimizer created with use_ema = TRUE.

Note that the weights are swapped in-place in order to save memory. The behavior is undefined if you modify the EMA weights or model weights in other callbacks.
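The swap-and-restore idea can be pictured with a small base-R sketch. This is a conceptual illustration only, not the callback's actual implementation: the `state` environment, its fields, and `swap_weights()` are hypothetical names invented here.

```r
# Conceptual sketch only; none of these names exist in keras3.
state <- new.env()
state$model_weights <- c(5, 6)      # made-up current training weights
state$ema_weights   <- c(3.5, 4.5)  # made-up EMA weights

# Exchange the two sets of values without keeping a third long-lived copy
swap_weights <- function(state) {
  tmp <- state$model_weights
  state$model_weights <- state$ema_weights
  state$ema_weights <- tmp
  invisible(state)
}

swap_weights(state)                     # before evaluation: model holds EMA values
evaluated_with <- state$model_weights   # evaluation would read these values
swap_weights(state)                     # after evaluation: training weights are back
```

Because the second swap assumes that nothing changed in between, any other code that writes to either set of weights between the two swaps would leave the restored values wrong, which is why the behavior is described as undefined in that case.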

Usage

callback_swap_ema_weights(swap_on_epoch = FALSE)

Arguments

swap_on_epoch

Whether to perform swapping at on_epoch_begin() and on_epoch_end(). This is useful if you want to use EMA weights for other callbacks such as callback_model_checkpoint(). Defaults to FALSE.

Value

A Callback instance that can be passed to fit.keras.src.models.model.Model().

Examples

# Remember to set `use_ema=TRUE` in the optimizer
optimizer <- optimizer_sgd(use_ema = TRUE)
model |> compile(optimizer = optimizer, loss = ..., metrics = ...)

# Metrics will be computed with EMA weights
model |> fit(X_train, Y_train,
             callbacks = c(callback_swap_ema_weights()))

# To save a model checkpoint with EMA weights, set `swap_on_epoch = TRUE`
# and place callback_model_checkpoint() after callback_swap_ema_weights().
model |> fit(
  X_train, Y_train,
  callbacks = c(
    callback_swap_ema_weights(swap_on_epoch = TRUE),
    callback_model_checkpoint(...)
  )
)