This layer wraps a flax.linen.Module instance so that Flax components can be used within Keras. It requires JAX as the Keras backend.

The module method to use for the forward pass can be specified via the method argument and is __call__ by default. This method must take the following arguments with these exact names:

  • self if the method is bound to the module, which is the case for the default, __call__; otherwise module, through which the module is passed.

  • inputs: the inputs to the model, a JAX array or a PyTree of arrays.

  • training (optional): a flag indicating whether the layer is called in training mode or inference mode; TRUE is passed in training mode.

FlaxLayer handles the non-trainable state of your model and required RNGs automatically. Note that the mutable parameter of flax.linen.Module.apply() is set to DenyList(["params"]), so all variables outside of the "params" collection are assumed to be non-trainable weights.

This example shows how to create a FlaxLayer from a Flax Module with the default __call__ method and no training argument:

# keras3::use_backend("jax")
# py_install("flax", "r-keras")

if(config_backend() == "jax" &&
   reticulate::py_module_available("flax")) {

flax <- import("flax")

MyFlaxModule(flax$linen$Module) %py_class% {
  `__call__` <- flax$linen$compact(\(self, inputs) {
    inputs |>
      (flax$linen$Conv(features = 32L, kernel_size = tuple(3L, 3L)))() |>
      flax$linen$relu() |>
      flax$linen$avg_pool(window_shape = tuple(2L, 2L),
                          strides = tuple(2L, 2L)) |>
      # flatten all except batch_size axis
      (\(x) x$reshape(tuple(x$shape[[1]], -1L)))() |>
      (flax$linen$Dense(features = 200L))() |>
      flax$linen$relu() |>
      (flax$linen$Dense(features = 10L))() |>
      flax$linen$softmax()
  })
}

# typical usage:
input <- keras_input(c(28, 28, 3))
output <- input |>
  layer_flax_module_wrapper(MyFlaxModule())

model <- keras_model(input, output)

# to instantiate the layer before composing:
flax_module <- MyFlaxModule()
keras_layer <- layer_flax_module_wrapper(module = flax_module)

input <- keras_input(c(28, 28, 3))
output <- input |>
  keras_layer()

model <- keras_model(input, output)

}

This example shows how to wrap the module method to conform to the required signature. This allows having multiple input arguments and a training argument that has a different name and values. This additionally shows how to use a function that is not bound to the module.

flax <- import("flax")

MyFlaxModule(flax$linen$Module) %py_class% {
  forward <-
    flax$linen$compact(\(self, input1, input2, deterministic) {
      # do work ....
      outputs # return
    })
}

# Adapts the (module, inputs, training) signature expected by the layer
# to the forward(input1, input2, deterministic) method of the module.
my_flax_module_wrapper <- function(module, inputs, training) {
  c(input1, input2) %<-% inputs
  module$forward(input1, input2, !training)
}

flax_module <- MyFlaxModule()
keras_layer <- layer_flax_module_wrapper(module = flax_module,
                                         method = my_flax_module_wrapper)
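
The adaptation performed by the wrapper can be tried out without Flax or Keras. The following is a minimal base R sketch, not the layer's actual machinery: fake_module is a hypothetical stand-in (a plain list with a forward function), and the inputs are unpacked with [[ instead of %<-%. It only illustrates how the (module, inputs, training) signature maps onto a method that takes two inputs and a deterministic flag.

```r
# Hypothetical stand-in for a Flax module: a plain list holding a `forward`
# function with the same argument names as in the example above.
fake_module <- list(
  forward = function(input1, input2, deterministic) {
    list(sum = input1 + input2, deterministic = deterministic)
  }
)

# Same shape as my_flax_module_wrapper(), written with base R only.
wrapper <- function(module, inputs, training) {
  input1 <- inputs[[1]]
  input2 <- inputs[[2]]
  # training = TRUE corresponds to deterministic = FALSE
  module$forward(input1, input2, !training)
}

res_train <- wrapper(fake_module, list(1, 2), training = TRUE)
res_infer <- wrapper(fake_module, list(1, 2), training = FALSE)
```

Here res_train$deterministic is FALSE and res_infer$deterministic is TRUE, which is the inversion that the wrapper applies before calling the module method.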

Usage

layer_flax_module_wrapper(object, module, method = NULL, variables = NULL, ...)

Arguments

object

Object to compose the layer with. A tensor, array, or sequential model.

module

An instance of flax.linen.Module or subclass.

method

The method to call the model. This is generally a method in the Module. If not provided, the __call__ method is used. method can also be a function not defined in the Module, in which case it must take the Module as the first argument. It is used for both Module.init and Module.apply. Details are documented in the method argument of flax.linen.Module.apply().

variables

A dict (named R list) containing all the variables of the module in the same format as what is returned by flax.linen.Module.init(). It should contain a "params" key and, if applicable, other keys for collections of variables for non-trainable state. This allows passing trained parameters and learned non-trainable state or controlling the initialization. If NULL is passed, the module's init function is called at build time to initialize the variables of the model.

...

For forward/backward compatibility.

Value

The return value depends on the value provided for the first argument. If object is:

  • a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.

  • a keras_input(), then the output tensor from calling layer(input) is returned.

  • NULL or missing, then a Layer instance is returned.

See also

Other wrapping layers:
layer_jax_model_wrapper()
layer_torch_module_wrapper()

Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_feature_space()
layer_flatten()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()