RematScope

RematScope class

keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)

A context manager for enabling rematerialization in Keras.

Rematerialization (gradient checkpointing) reduces memory usage at the cost of extra computation. Instead of storing every intermediate activation from the forward pass, it recomputes some of them during the backward pass. This is particularly useful when training large models or using large batch sizes under limited memory.
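
The trade-off can be seen in a framework-free toy sketch in plain Python. This is not Keras code. It treats a chain of scalar functions as a "network" and computes the same gradient two ways. The first keeps every intermediate activation. The second keeps only every few activations as checkpoints and recomputes the rest of each segment during the backward pass. The function names and the segment size every are hypothetical.

```python
# Hypothetical toy illustration of gradient checkpointing (not Keras internals).
import math
from typing import Callable, Dict, List, Tuple

# A toy "network": a chain of scalar ops, each paired with its derivative.
Op = Tuple[Callable[[float], float], Callable[[float], float]]
OPS: List[Op] = [
    (math.sin, math.cos),
    (math.tanh, lambda x: 1.0 - math.tanh(x) ** 2),
] * 8


def forward(x0: float, ops: List[Op]) -> float:
    x = x0
    for f, _ in ops:
        x = f(x)
    return x


def grad_store_all(x0: float, ops: List[Op]) -> Tuple[float, int]:
    """d(output)/d(x0), keeping every op's input in memory for the backward pass."""
    acts: List[float] = []
    x = x0
    for f, _ in ops:
        acts.append(x)
        x = f(x)
    grad = 1.0
    for (_, df), a in zip(reversed(ops), reversed(acts)):
        grad *= df(a)  # chain rule, from the last op back to the first
    return grad, len(acts)


def grad_checkpointed(x0: float, ops: List[Op], every: int) -> Tuple[float, int, int]:
    """Same gradient, storing only every `every`-th activation and recomputing the rest."""
    checkpoints: Dict[int, float] = {}
    x = x0
    for i, (f, _) in enumerate(ops):
        if i % every == 0:
            checkpoints[i] = x  # keep only segment boundaries
        x = f(x)
    grad, recomputed, peak = 1.0, 0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, len(ops))
        # Rematerialize this segment's activations from its checkpoint.
        seg = [checkpoints[start]]
        for f, _ in ops[start:end - 1]:
            seg.append(f(seg[-1]))
            recomputed += 1
        # Activations held now: all checkpoints plus the recomputed ones.
        peak = max(peak, len(checkpoints) + len(seg) - 1)
        for (_, df), a in zip(reversed(ops[start:end]), reversed(seg)):
            grad *= df(a)
    return grad, peak, recomputed


g_full, stored = grad_store_all(0.5, OPS)
g_ckpt, peak, recomputed = grad_checkpointed(0.5, OPS, every=4)
print(math.isclose(g_full, g_ckpt), stored, peak, recomputed)  # True 16 7 12
```

With 16 ops and every=4, the checkpointed version holds at most 7 activations instead of 16. It pays for this with 12 extra function evaluations and produces the same gradient.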

Enter the scope where the layers are created and called (e.g., layer(input)), as in the examples below. Rematerialization itself takes effect at execution time (when the layer runs and gradients are computed), not at creation time.

Arguments

  • mode: Rematerialization mode to apply. Options:
    • "full": Apply rematerialization globally to all supported operations.
    • "activations": Apply rematerialization only to the activation functions of layers that have one (e.g., Dense(..., activation="relu")).
    • "larger_than": Apply rematerialization to layers with output sizes larger than output_size_threshold.
    • "list_of_layers": Apply rematerialization to a specific list of layer names.
    • None: Disable rematerialization.
  • output_size_threshold: Output size threshold for the "larger_than" mode. Layers producing outputs larger than this threshold will be rematerialized. Default is 1024.
  • layer_names: List of layer names for the "list_of_layers" mode. Defaults to None, which is treated as an empty list.
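
The way each mode selects what to rematerialize can be summarized as a small decision function. This is a hypothetical, simplified sketch in plain Python, not Keras internals. The name remat_target is invented. The sketch also assumes that "output size" means the number of elements in the output tensor, with unknown dimensions counted as 1.

```python
# Hypothetical summary of the modes; not the Keras implementation.
import math
from typing import Optional, Sequence


def remat_target(
    mode: Optional[str],
    layer_name: str,
    output_shape: Sequence[Optional[int]],
    has_activation: bool,
    output_size_threshold: int = 1024,
    layer_names: Optional[Sequence[str]] = None,
) -> Optional[str]:
    """Return "layer", "activation", or None (no rematerialization)."""
    layer_names = list(layer_names or [])  # None behaves like an empty list
    if mode is None:
        return None
    if mode == "full":
        return "layer"
    if mode == "activations":
        return "activation" if has_activation else None
    if mode == "larger_than":
        # Assumption: size = number of elements; unknown dims (None) count as 1.
        size = math.prod(d or 1 for d in output_shape)
        return "layer" if size > output_size_threshold else None
    if mode == "list_of_layers":
        return "layer" if layer_name in layer_names else None
    raise ValueError(f"Unknown mode: {mode!r}")


print(remat_target("larger_than", "conv2d", (1, 30, 30, 64), False,
                   output_size_threshold=2048))  # layer
```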

Examples

Using "list_of_layers" mode:

import keras
from keras import RematScope

input_tensor = keras.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1", "conv2d_1"]):
    layer1 = keras.layers.Dense(128, name="dense_1")
    layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
    layer3 = keras.layers.Dense(64, name="dense_2")
    # Only layer1 and layer2 are rematerialized; "dense_2" is not listed
    output1 = layer1(input_tensor)
    output2 = layer2(output1)
    output3 = layer3(output2)

Using "larger_than" mode with a specific output size threshold:

with RematScope(mode="larger_than", output_size_threshold=2048):
    layer = keras.layers.Conv2D(64, (3, 3))
    # Rematerialized, since the (1, 30, 30, 64) output is larger than 2048
    output = layer(input_tensor)
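
The output size in this example can be checked by hand. Conv2D defaults to padding="valid" and strides of 1, so with a channels-last input of shape (1, 32, 32, 3) and a 3x3 kernel each spatial dimension shrinks from 32 to 32 - 3 + 1 = 30. The sketch below does the arithmetic in plain Python. It assumes that size is counted in elements; the helper name conv2d_valid_output_shape is hypothetical.

```python
import math


def conv2d_valid_output_shape(input_shape, filters, kernel_size, strides=(1, 1)):
    """Hypothetical helper: output shape of a channels-last conv with "valid" padding."""
    batch, h, w, _ = input_shape
    kh, kw = kernel_size
    sh, sw = strides
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    return (batch, out_h, out_w, filters)


shape = conv2d_valid_output_shape((1, 32, 32, 3), filters=64, kernel_size=(3, 3))
size = math.prod(shape)  # number of elements in the output
print(shape, size, size > 2048)  # (1, 30, 30, 64) 57600 True
```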

Nested scopes for fine-grained control:

with RematScope(mode="full"):
    layer1 = keras.layers.Dense(128, activation="relu")
    output1 = layer1(input_tensor)  # layer1 is fully rematerialized
    with RematScope(mode="larger_than", output_size_threshold=512):
        layer2 = keras.layers.Conv2D(32, (3, 3))
        # The inner scope takes precedence: layer2 is rematerialized
        # only if its output size exceeds 512
        output2 = layer2(output1)
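
Nesting behaves like a stack: the innermost active scope decides the mode, and leaving it restores the enclosing one. The toy class below is a hypothetical stand-in written in plain Python that illustrates these semantics. ToyRematScope and current_mode are invented names, and the sketch does not reproduce how Keras stores scope state.

```python
# Hypothetical stand-in illustrating innermost-scope-wins; not Keras internals.
from typing import List, Optional, Sequence


class ToyRematScope:
    """Toy context manager keeping a stack of active remat settings."""

    _stack: List["ToyRematScope"] = []

    def __init__(self, mode: Optional[str] = "full", output_size_threshold: int = 1024,
                 layer_names: Optional[Sequence[str]] = None):
        self.mode = mode
        self.output_size_threshold = output_size_threshold
        self.layer_names = list(layer_names or [])

    def __enter__(self) -> "ToyRematScope":
        ToyRematScope._stack.append(self)  # innermost scope goes on top
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        ToyRematScope._stack.pop()  # restore the enclosing scope


def current_mode() -> Optional[str]:
    """Mode of the innermost active scope, or None outside any scope."""
    return ToyRematScope._stack[-1].mode if ToyRematScope._stack else None


seen = []
with ToyRematScope(mode="full"):
    seen.append(current_mode())  # "full"
    with ToyRematScope(mode="larger_than", output_size_threshold=512):
        seen.append(current_mode())  # "larger_than" overrides "full"
    seen.append(current_mode())  # back to "full"
seen.append(current_mode())  # None: no scope active
print(seen)  # ['full', 'larger_than', 'full', None]
```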