remat
keras.remat(f)
Applies rematerialization to a function or layer for memory optimization.
Rematerialization is a memory optimization technique that trades off computation for memory. Instead of storing intermediate results (e.g. activations) for backpropagation, they are recomputed during the backward pass. This reduces peak memory usage at the cost of increased computation time, allowing the training of larger models or using larger batch sizes within the same memory constraints.
Arguments

f: The function or layer to apply rematerialization to.
Returns
A wrapped function that applies rematerialization. The returned function defines a custom gradient, ensuring that during the backward pass, the forward computation is recomputed as needed.
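The custom-gradient behavior can be illustrated with `jax.checkpoint` (also exposed as `jax.remat`), the JAX primitive this mechanism corresponds to on the JAX backend: gradients of the wrapped function match the original, while activations are recomputed rather than cached. A hedged sketch, assuming the JAX backend:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# jax.checkpoint recomputes f's intermediates in the backward pass,
# so the gradient is numerically identical to the uncheckpointed one.
g_plain = jax.grad(f)(jnp.ones(3))
g_remat = jax.grad(jax.checkpoint(f))(jnp.ones(3))
assert jnp.allclose(g_plain, g_remat)
```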
Example
from keras import Model, layers, remat


class CustomRematLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.remat_function = remat(self.intermediate_function)

    def intermediate_function(self, x):
        for _ in range(2):
            x = x + x * 0.1  # Simple scaled transformation
        return x

    def call(self, inputs):
        return self.remat_function(inputs)


# Define a simple model using the custom layer
inputs = layers.Input(shape=(4,))
x = layers.Dense(4, activation="relu")(inputs)
x = CustomRematLayer()(x)  # Custom layer with rematerialization
outputs = layers.Dense(1)(x)

# Create and compile the model
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")