Author: Khalid Salama
Date created: 2021/04/30
Last modified: 2023/12/30
Description: Implementing the Perceiver model for image classification.
This example implements the Perceiver: General Perception with Iterative Attention model by Andrew Jaegle et al. for image classification, and demonstrates it on the CIFAR-100 dataset.
The Perceiver model leverages an asymmetric attention mechanism to iteratively distill inputs into a tight latent bottleneck, allowing it to scale to handle very large inputs.
In other words: let's assume that your input data array (e.g. image) has M elements (i.e. patches), where M is large. In a standard Transformer model, a self-attention operation is performed for the M elements. The complexity of this operation is O(M^2). However, the Perceiver model creates a latent array of size N elements, where N << M, and performs two operations iteratively:
1. Cross-attention between the latent array and the data array, with complexity O(M.N).
2. Self-attention (a latent Transformer) on the latent array, with complexity O(N^2).
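To get a concrete sense of the savings, here is a rough back-of-the-envelope sketch (not part of the original example) that counts attention-score entries using the sizes chosen later in this example: M = 1024 image patches and N = 256 latent elements. Constant factors and the number of layers are ignored.
# Back-of-the-envelope count of attention-score entries (illustration only).
M, N = 1024, 256  # data array size and latent array size used later in this example
print("Standard self-attention over the data, O(M^2):", M * M)  # 1,048,576
print("Cross-attention (latent attends to data), O(M.N):", M * N)  # 262,144
print("Self-attention over the latent array, O(N^2):", N * N)  # 65,536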
This example requires Keras 3.0 or higher.
import keras
from keras import layers, activations, ops
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 64
num_epochs = 2 # You should actually use 50 epochs!
dropout_rate = 0.2
image_size = 64 # We'll resize input images to this size.
patch_size = 2 # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2 # Size of the data array.
latent_dim = 256 # Size of the latent array.
projection_dim = 256 # Embedding size of each element in the data and latent arrays.
num_heads = 8 # Number of Transformer heads.
ffn_units = [
projection_dim,
projection_dim,
] # Size of the Transformer Feedforward network.
num_transformer_blocks = 4
num_iterations = 2 # Repetitions of the cross-attention and Transformer modules.
classifier_units = [
projection_dim,
num_classes,
] # Size of the Feedforward network of the final classifier.
print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
print(f"Latent array shape: {latent_dim} X {projection_dim}")
print(f"Data array shape: {num_patches} X {projection_dim}")
Image size: 64 X 64 = 4096
Patch size: 2 X 2 = 4
Patches per image: 1024
Elements per patch (3 channels): 12
Latent array shape: 256 X 256
Data array shape: 1024 X 256
Note that, in order to use each pixel as an individual input in the data array, set patch_size to 1.
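For example, keeping image_size = 64, setting patch_size to 1 would change the data array as sketched below (illustration only; pixel_patch_size is just a local name used here, and the hyperparameters above are left unchanged):
# Illustration only: with a patch size of 1, each pixel becomes one element of the data array.
pixel_patch_size = 1
print("Patches per image:", (image_size // pixel_patch_size) ** 2)  # 4096
print("Elements per patch (3 channels):", (pixel_patch_size**2) * 3)  # 3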
data_augmentation = keras.Sequential(
[
layers.Normalization(),
layers.Resizing(image_size, image_size),
layers.RandomFlip("horizontal"),
layers.RandomZoom(height_factor=0.2, width_factor=0.2),
],
name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
def create_ffn(hidden_units, dropout_rate):
ffn_layers = []
for units in hidden_units[:-1]:
ffn_layers.append(layers.Dense(units, activation=activations.gelu))
ffn_layers.append(layers.Dense(units=hidden_units[-1]))
ffn_layers.append(layers.Dropout(dropout_rate))
ffn = keras.Sequential(ffn_layers)
return ffn
class Patches(layers.Layer):
def __init__(self, patch_size):
super().__init__()
self.patch_size = patch_size
def call(self, images):
batch_size = ops.shape(images)[0]
patches = ops.image.extract_patches(
image=images,
size=(self.patch_size, self.patch_size),
strides=(self.patch_size, self.patch_size),
dilation_rate=1,
padding="valid",
)
patch_dims = patches.shape[-1]
patches = ops.reshape(patches, [batch_size, -1, patch_dims])
return patches
The PatchEncoder layer will linearly transform a patch by projecting it into a vector of size projection_dim. In addition, it adds a learnable position embedding to the projected vector. Note that the original Perceiver paper uses Fourier feature positional encodings.
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super().__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patches):
positions = ops.arange(start=0, stop=self.num_patches, step=1)
encoded = self.projection(patches) + self.position_embedding(positions)
return encoded
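As a quick, optional sanity check (not part of the original example), we can run a dummy batch through the Patches and PatchEncoder layers to confirm the shapes they produce:
# Optional shape check on a dummy batch containing one resized image.
_images = ops.zeros((1, image_size, image_size, 3))
_patches = Patches(patch_size)(_images)
print(_patches.shape)  # (1, 1024, 12): 1024 patches of 2 x 2 x 3 values each
_encoded = PatchEncoder(num_patches, projection_dim)(_patches)
print(_encoded.shape)  # (1, 1024, 256): each patch projected to projection_dim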
The Perceiver consists of two modules: a cross-attention module and a standard Transformer with self-attention.

The cross-attention module expects a (latent_dim, projection_dim) latent array and a (data_dim, projection_dim) data array as inputs, and produces a (latent_dim, projection_dim) latent array as output. To apply cross-attention, the query vectors are generated from the latent array, while the key and value vectors are generated from the encoded image.

Note that the data array in this example is the encoded image, where data_dim is set to num_patches.
def create_cross_attention_module(
latent_dim, data_dim, projection_dim, ffn_units, dropout_rate
):
inputs = {
# Receive the latent array as an input of shape [1, latent_dim, projection_dim].
"latent_array": layers.Input(
shape=(latent_dim, projection_dim), name="latent_array"
),
# Receive the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
"data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
}
# Apply layer norm to the inputs
latent_array = layers.LayerNormalization(epsilon=1e-6)(inputs["latent_array"])
data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])
# Create query tensor: [1, latent_dim, projection_dim].
query = layers.Dense(units=projection_dim)(latent_array)
# Create key tensor: [batch_size, data_dim, projection_dim].
key = layers.Dense(units=projection_dim)(data_array)
# Create value tensor: [batch_size, data_dim, projection_dim].
value = layers.Dense(units=projection_dim)(data_array)
# Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].
attention_output = layers.Attention(use_scale=True, dropout=0.1)(
[query, key, value], return_attention_scores=False
)
# Skip connection 1.
attention_output = layers.Add()([attention_output, latent_array])
# Apply layer norm.
attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)
# Apply Feedforward network.
ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
outputs = ffn(attention_output)
# Skip connection 2.
outputs = layers.Add()([outputs, attention_output])
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=outputs)
return model
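As an optional sanity check (not in the original example), the cross-attention module can be built and called on dummy inputs. The latent array is passed with a leading batch dimension of 1, exactly as the Perceiver model does below, and is broadcast against the batch of encoded images:
# Optional shape check of the cross-attention module on dummy inputs.
_cross_attention = create_cross_attention_module(
    latent_dim, num_patches, projection_dim, ffn_units, dropout_rate
)
_outputs = _cross_attention(
    {
        "latent_array": ops.zeros((1, latent_dim, projection_dim)),
        "data_array": ops.zeros((2, num_patches, projection_dim)),
    }
)
print(_outputs.shape)  # (2, 256, 256): one refined latent array per image in the batch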
The Transformer expects the output latent vector from the cross-attention module as input, applies multi-head self-attention to its latent_dim elements, followed by a feedforward network, to produce another (latent_dim, projection_dim) latent array.
def create_transformer_module(
latent_dim,
projection_dim,
num_heads,
num_transformer_blocks,
ffn_units,
dropout_rate,
):
# input_shape: [1, latent_dim, projection_dim]
inputs = layers.Input(shape=(latent_dim, projection_dim))
x0 = inputs
# Create multiple layers of the Transformer block.
for _ in range(num_transformer_blocks):
# Apply layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(x0)
# Create a multi-head self-attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, x0])
# Apply layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# Apply Feedforward network.
ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
x3 = ffn(x3)
# Skip connection 2.
x0 = layers.Add()([x3, x2])
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=x0)
return model
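Similarly, a quick optional check (not in the original example) confirms that the Transformer module preserves the latent array shape, which is what allows it to be applied repeatedly in the iterative loop below:
# Optional shape check: the Transformer module maps the latent array to an array of the same shape.
_transformer = create_transformer_module(
    latent_dim, projection_dim, num_heads, num_transformer_blocks, ffn_units, dropout_rate
)
print(_transformer(ops.zeros((1, latent_dim, projection_dim))).shape)  # (1, 256, 256)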
The Perceiver model repeats the cross-attention and Transformer modules num_iterations times, with shared weights and skip connections, to allow the latent array to iteratively extract information from the input image as it is needed.
class Perceiver(keras.Model):
def __init__(
self,
patch_size,
data_dim,
latent_dim,
projection_dim,
num_heads,
num_transformer_blocks,
ffn_units,
dropout_rate,
num_iterations,
classifier_units,
):
super().__init__()
self.latent_dim = latent_dim
self.data_dim = data_dim
self.patch_size = patch_size
self.projection_dim = projection_dim
self.num_heads = num_heads
self.num_transformer_blocks = num_transformer_blocks
self.ffn_units = ffn_units
self.dropout_rate = dropout_rate
self.num_iterations = num_iterations
self.classifier_units = classifier_units
def build(self, input_shape):
# Create latent array.
self.latent_array = self.add_weight(
shape=(self.latent_dim, self.projection_dim),
initializer="random_normal",
trainable=True,
)
# Create patching module.
self.patcher = Patches(self.patch_size)
# Create patch encoder.
self.patch_encoder = PatchEncoder(self.data_dim, self.projection_dim)
# Create cross-attention module.
self.cross_attention = create_cross_attention_module(
self.latent_dim,
self.data_dim,
self.projection_dim,
self.ffn_units,
self.dropout_rate,
)
# Create Transformer module.
self.transformer = create_transformer_module(
self.latent_dim,
self.projection_dim,
self.num_heads,
self.num_transformer_blocks,
self.ffn_units,
self.dropout_rate,
)
# Create global average pooling layer.
self.global_average_pooling = layers.GlobalAveragePooling1D()
# Create a classification head.
self.classification_head = create_ffn(
hidden_units=self.classifier_units, dropout_rate=self.dropout_rate
)
super().build(input_shape)
def call(self, inputs):
# Augment data.
augmented = data_augmentation(inputs)
# Create patches.
patches = self.patcher(augmented)
# Encode patches.
encoded_patches = self.patch_encoder(patches)
# Prepare cross-attention inputs.
cross_attention_inputs = {
"latent_array": ops.expand_dims(self.latent_array, 0),
"data_array": encoded_patches,
}
# Apply the cross-attention and the Transformer modules iteratively.
for _ in range(self.num_iterations):
# Apply cross-attention from the latent array to the data array.
latent_array = self.cross_attention(cross_attention_inputs)
# Apply self-attention Transformer to the latent array.
latent_array = self.transformer(latent_array)
# Set the latent array of the next iteration.
cross_attention_inputs["latent_array"] = latent_array
# Apply global average pooling to generate a [batch_size, projection_dim] representation tensor.
representation = self.global_average_pooling(latent_array)
# Generate logits.
logits = self.classification_head(representation)
return logits
def run_experiment(model):
# Use the Adam optimizer with weight decay (the original paper uses the LAMB optimizer).
optimizer = keras.optimizers.Adam(learning_rate=learning_rate, weight_decay=weight_decay)
# Compile the model.
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="acc"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
],
)
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.2, patience=3
)
# Create an early stopping callback.
early_stopping = keras.callbacks.EarlyStopping(
monitor="val_loss", patience=15, restore_best_weights=True
)
# Fit the model.
history = model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=0.1,
callbacks=[early_stopping, reduce_lr],
)
_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
# Return history to plot learning curves.
return history
Note that training the Perceiver model with the current settings on a V100 GPU takes around 200 seconds.
perceiver_classifier = Perceiver(
patch_size,
num_patches,
latent_dim,
projection_dim,
num_heads,
num_transformer_blocks,
ffn_units,
dropout_rate,
num_iterations,
classifier_units,
)
history = run_experiment(perceiver_classifier)
Test accuracy: 0.91%
Test top 5 accuracy: 5.2%
After 40 epochs, the Perceiver model achieves around 53% accuracy and 81% top-5 accuracy on the test data.
As mentioned in the ablations of the Perceiver paper, you can obtain better results by increasing the latent array size, increasing the (projection) dimensions of the latent array and data array elements, increasing the number of blocks in the Transformer module, and increasing the number of iterations of applying the cross-attention and the latent Transformer modules. You may also try to increase the size of the input images and use different patch sizes.
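For instance, a scaled-up configuration along those lines might look like the following sketch; the values are purely illustrative and are not taken from the paper:
# Hypothetical scaled-up settings (illustrative values only, not from the paper).
larger_config = {
    "latent_dim": 512,  # bigger latent array
    "projection_dim": 512,  # bigger element embeddings
    "num_transformer_blocks": 6,  # deeper latent Transformer
    "num_iterations": 4,  # more cross-attention / latent Transformer repetitions
    "image_size": 128,  # larger input images
    "patch_size": 4,  # a different patch size
}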
The Perceiver benefits from increasing the model size. However, larger models need bigger accelerators to fit in memory and train efficiently. This is why the Perceiver paper used 32 TPU cores to run its experiments.