
You can use this class to quickly build a mean metric from a function. The function must have the signature fn(y_true, y_pred) and return an array of per-sample values (for example, a per-sample loss). Calling result() on the resulting metric returns the average metric value across all samples seen so far.

For example:

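library(keras3)

# Per-sample squared error, computed element-wise
mse <- function(y_true, y_pred) {
  (y_true - y_pred)^2
}

# Wrap it so the metric averages the per-sample values
mse_metric <- metric_mean_wrapper(fn = mse)
mse_metric$update_state(c(0, 1), c(1, 1))  # per-sample values: 1, 0
mse_metric$result()                         # mean over both samples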

## tf.Tensor(0.5, shape=(), dtype=float32)
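"Across all samples seen so far" means the metric keeps a running total and a running count over every update_state() call, rather than averaging the per-batch means. The base-R sketch below mimics that bookkeeping with a hypothetical make_running_mean() helper. It is an illustration under that assumption, not the Keras implementation, and it ignores sample weights.

```r
# Hypothetical helper (not part of keras3) illustrating how a mean metric
# accumulates state across batches. Sample weights are ignored.
make_running_mean <- function(fn) {
  total <- 0  # sum of all per-sample values seen so far
  count <- 0  # number of samples seen so far
  list(
    update_state = function(y_true, y_pred) {
      values <- fn(y_true, y_pred)  # per-sample metric values
      total <<- total + sum(values)
      count <<- count + length(values)
      invisible(NULL)
    },
    result = function() total / count
  )
}

mse <- function(y_true, y_pred) (y_true - y_pred)^2

m <- make_running_mean(mse)
m$update_state(c(0, 1), c(1, 1))              # per-sample values: 1, 0
m$update_state(c(1, 1, 1, 1), c(0, 0, 0, 0))  # per-sample values: 1, 1, 1, 1
m$result()  # 5 / 6 over all six samples, not the mean of batch means (0.75)
```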

Usage

metric_mean_wrapper(..., fn, name = NULL, dtype = NULL)

Arguments

...

Additional named arguments passed on to fn.

fn

The metric function to wrap, with signature fn(y_true, y_pred).

name

(Optional) string name of the metric instance.

dtype

(Optional) data type of the metric result.
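Named arguments supplied through ... are stored with the metric and forwarded to fn alongside y_true and y_pred each time it is called. The base-R sketch below shows that forwarding pattern with a hypothetical call_with_extras() helper and a hypothetical huber() function that takes an extra delta argument. It illustrates the idea only; the real wrapper does the forwarding inside Keras, where fn receives tensors rather than plain R vectors.

```r
# Hypothetical per-sample Huber loss with an extra `delta` argument
# (plain R vectors only; not written for tensors).
huber <- function(y_true, y_pred, delta = 1) {
  err <- abs(y_true - y_pred)
  ifelse(err <= delta, 0.5 * err^2, delta * (err - 0.5 * delta))
}

# Hypothetical helper: capture extra named arguments once, then forward
# them to `fn` together with y_true and y_pred on every call.
call_with_extras <- function(fn, ...) {
  extras <- list(...)
  function(y_true, y_pred) do.call(fn, c(list(y_true, y_pred), extras))
}

f <- call_with_extras(huber, delta = 0.5)
f(c(0, 0), c(0.2, 2))  # per-sample values computed with delta = 0.5
```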

Value

A Metric instance. It can be passed directly to compile(metrics = ) or used as a standalone object. See ?Metric for example usage.