Author: Aritra Roy Gosthipaty, Suvaditya Mukherjee
Date created: 2023/03/12
Last modified: 2023/03/12
Description: Image Classification with Temporal Latent Bottleneck Networks.
A simple Recurrent Neural Network (RNN) displays a strong inductive bias towards learning
temporally compressed representations. Equation 1 shows the recurrence formula,
where h_t is the compressed representation (a single vector) of the entire input
sequence x.
Equation 1: The recurrence equation. (Source: Aritra and Suvaditya)
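As a rough illustration of what such a recurrence does, here is a minimal sketch in plain Python. The scalar weights `w_x` and `w_h` and the `tanh` non-linearity are assumptions (the common textbook form of a simple RNN), not values read from Equation 1; the point is only that a sequence of any length is folded into a single state.

```python
import math


def simple_rnn(sequence, w_x=0.5, w_h=0.9, h0=0.0):
    """Fold a sequence into one state: h_t = tanh(w_x * x_t + w_h * h_{t-1})."""
    h = h0
    for x_t in sequence:
        h = math.tanh(w_x * x_t + w_h * h)
    return h


short_summary = simple_rnn([1.0, -2.0, 0.5])
long_summary = simple_rnn([1.0, -2.0, 0.5] * 100)
# Both sequences, regardless of length, are summarized by a single number.
print(short_summary, long_summary)
```

In a real RNN the state is a vector and the weights are learned matrices, but the compression into one fixed-size state works the same way.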
On the other hand, Transformers (Vaswani et al.) have little inductive bias towards learning temporally compressed representations. The Transformer has achieved state-of-the-art (SoTA) results in Natural Language Processing (NLP) and vision tasks with its pairwise attention mechanism.
While the Transformer can attend to different sections of the input sequence, the cost of computing attention grows quadratically with the sequence length.
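To make the quadratic growth concrete, the sketch below counts the query-key scores that full pairwise attention computes. The sequence lengths are arbitrary illustrative values, and the helper name `attention_score_count` is made up for this example.

```python
def attention_score_count(seq_len):
    """Number of query-key scores computed by full pairwise attention."""
    return seq_len * seq_len


for n in (64, 128, 256):
    # Doubling the sequence length quadruples the number of scores.
    print(n, attention_score_count(n))
```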
Didolkar et al. argue that having a more compressed representation of a sequence may be beneficial for generalization, as it can be easily re-used and re-purposed with fewer irrelevant details. While compression is good, they also note that too much of it can harm expressiveness.
The authors propose a solution that divides computation into two streams: a slow stream that is recurrent in nature and a fast stream that is parameterized as a Transformer. While this method is novel in introducing different processing streams to preserve and process latent states, it has parallels in other works such as the Perceiver (Jaegle et al.) and Grounded Language Learning Fast and Slow (Hill et al.).
The following example explores how we can make use of the new Temporal Latent Bottleneck
mechanism to perform image classification on the CIFAR-10 dataset. We implement this
model with a custom RNNCell implementation in order to get a performant and
vectorized design.
Note: This example makes use of TensorFlow 2.12.0, which must be installed on our
system.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import AdamW
import random
from matplotlib import pyplot as plt
# Set seed for reproducibility.
keras.utils.set_random_seed(42)
AUTO = tf.data.AUTOTUNE
We set a few configuration parameters that are needed within the pipeline we have designed. The current parameters are for use with the CIFAR10 dataset.
The model also supports mixed-precision settings, under which computations run in
16-bit floats where possible, while variables are kept in 32-bit as needed
for numerical stability. This brings performance benefits: the memory footprint of the model
decreases significantly and computation gets faster, including at inference time.
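The need to keep some values in 32-bit can be illustrated without TensorFlow. The sketch below uses only Python's `struct` module to round numbers to IEEE 754 half precision; the helper `to_float16` is a name introduced here for illustration and is not part of the example's pipeline.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]


# 0.1 is stored only approximately in half precision.
print(to_float16(0.1))
# Above 2048, not every integer is representable any more.
print(to_float16(2049.0))
# Very small magnitudes, such as tiny weight updates, underflow to zero.
print(to_float16(1e-8))
```

Losses of this kind are why a mixed-precision policy keeps the variables themselves in 32-bit while running the bulk of the arithmetic in 16-bit.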
config = {
    "mixed_precision": True,
    "dataset": "cifar10",
    "train_slice": 40_000,
    "batch_size": 2048,
    "buffer_size": 2048 * 2,
    "input_shape": [32, 32, 3],
    "image_size": 48,
    "num_classes": 10,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 30,
    "patch_size": 4,
    "embed_dim": 64,
    "chunk_size": 8,
    "r": 2,
    "num_layers": 4,
    "ffn_drop": 0.2,
    "attn_drop": 0.2,
    "num_heads": 1,
}

if config["mixed_precision"]:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA A100-PCIE-40GB, compute capability 8.0
We are going to use the CIFAR10 dataset for running our experiments. This dataset
contains a training set of 50,000 images for 10 classes with the standard image size
of (32, 32, 3).
It also has a separate set of 10,000 images with similar characteristics. More
information about the dataset may be found at the official site for the dataset as well
as the keras.datasets.cifar10 API reference.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
    (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)
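The split sizes and the number of batches per epoch follow directly from the configuration. The short calculation below is a sketch that hard-codes the values from `config`; the resulting 20 training steps match the `20/20` shown in the training log later in this example.

```python
import math

num_train_total = 50_000  # CIFAR-10 training images
train_slice = 40_000      # config["train_slice"]
batch_size = 2048         # config["batch_size"]

num_train = train_slice
num_val = num_train_total - train_slice
# The last, smaller batch is kept, hence the ceiling.
steps_per_epoch = math.ceil(num_train / batch_size)
val_steps = math.ceil(num_val / batch_size)
print(num_train, num_val, steps_per_epoch, val_steps)
```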
We define separate pipelines for performing image augmentation on our data. This step is important to make the model more robust to changes, helping it generalize better. The preprocessing and augmentation steps we perform are as follows:
- Rescaling (training, test): This step is performed to normalize all image pixel
values from the [0, 255] range to [0, 1]. This helps in maintaining numerical stability
later during training.
- Resizing (training, test): For training, we resize the image from its original size of (32, 32) to
(52, 52). This is done to account for the Random Crop, as well as to comply with the
specifications of the data given in the paper. The test pipeline resizes directly to (48, 48).
- RandomCrop (training): This layer randomly selects a crop/sub-region of the image
with size (48, 48).
- RandomFlip (training): This layer randomly flips the images horizontally,
keeping image sizes the same.
# Build the `train` augmentation pipeline.
train_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(
            config["input_shape"][0] + 20,
            config["input_shape"][0] + 20,
            dtype="float32",
        ),
        layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
        layers.RandomFlip("horizontal", dtype="float32"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
    ],
    name="test_data_augmentation",
)
# We define functions in place of simple lambda functions to run through the
# `keras.Sequential` in order to avoid this warning:
# (https://github.com/tensorflow/tensorflow/issues/56089)
def train_map_fn(image, label):
    return train_augmentation(image), label


def test_map_fn(image, label):
    return test_augmentation(image), label
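The crop and flip steps can be pictured on a plain 2D grid. The following is a simplified sketch in pure Python, not the Keras layers themselves: `random_crop` and `horizontal_flip` are names made up here, and the flip is always applied, whereas `RandomFlip` flips only some of the time.

```python
import random


def random_crop(image, size, rng):
    """Return a random size x size window from a 2D grid (list of rows)."""
    height, width = len(image), len(image[0])
    top = rng.randint(0, height - size)  # randint is inclusive on both ends
    left = rng.randint(0, width - size)
    return [row[left:left + size] for row in image[top:top + size]]


def horizontal_flip(image):
    """Mirror every row left-to-right; the image size is unchanged."""
    return [row[::-1] for row in image]


rng = random.Random(42)
# Stand-in for a resized 52 x 52 image: each "pixel" records its own position.
resized = [[(r, c) for c in range(52)] for r in range(52)]
crop = random_crop(resized, 48, rng)
flipped = horizontal_flip(crop)
print(len(crop), len(crop[0]))
```

Resizing to 52 before cropping to 48 leaves a margin of 4 pixels per axis, so the crop window can start at any of 5 offsets in each direction.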
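Load the dataset into a tf.data.Dataset object
- We take the np.ndarray instance of the datasets and move them into a
tf.data.Dataset instance.
- We apply the augmentations using .map().
- We shuffle the training dataset using .shuffle().
- We batch the dataset using .batch().
- We enable prefetching of batches using .prefetch().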
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(train_map_fn, num_parallel_calls=AUTO)
    .shuffle(config["buffer_size"])
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)
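The role of `buffer_size` in the pipeline above can be sketched with a bounded shuffle buffer followed by batching. This is a simplified stand-in written in pure Python, not TensorFlow's implementation; `shuffle_with_buffer` and `batch` are hypothetical helper names.

```python
import random


def shuffle_with_buffer(items, buffer_size, rng):
    """Yield items in a locally shuffled order using a bounded buffer."""
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) > buffer_size:
            # Emit a random element once the buffer is over capacity.
            yield buffer.pop(rng.randrange(len(buffer)))
    # Drain what is left at the end of the stream.
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


def batch(items, batch_size):
    """Group consecutive items into lists of at most batch_size elements."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


rng = random.Random(0)
shuffled = shuffle_with_buffer(range(10), buffer_size=4, rng=rng)
batches = list(batch(shuffled, batch_size=4))
print(batches)
```

A larger buffer mixes elements from farther apart in the stream; a buffer at least as large as the dataset gives a full shuffle.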
An excerpt from the paper:
In the brain, short-term and long-term memory have developed in a specialized way. Short-term memory is allowed to change very quickly to react to immediate sensory inputs and perception. By contrast, long-term memory changes slowly, is highly selective and involves repeated consolidation.
Inspired by short-term and long-term memory, the authors introduce fast stream and slow stream computation. The fast stream has a short-term memory with a high capacity that reacts quickly to sensory input (Transformers). The slow stream has long-term memory which updates at a slower rate and summarizes the most relevant information (Recurrence).
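To implement this idea we need to:
- Take a sequence of data.
- Divide the sequence into fixed-size chunks.
- Let the fast stream operate within each chunk.
- Let the slow stream consolidate and aggregate information across chunks.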
The fast and slow stream induce what is called information asymmetry. The two streams interact with each other through a bottleneck of attention. Figure 1 shows the architecture of the model.
Figure 1: Architecture of the model. (Source: https://arxiv.org/abs/2205.14794)
The authors also provide PyTorch-style pseudocode, shown in Algorithm 1.
Algorithm 1: PyTorch style pseudocode. (Source: https://arxiv.org/abs/2205.14794)
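The control flow of the two streams can be sketched in a few lines of plain Python. This is not the authors' pseudocode: `two_stream_scan`, `fast_step` and `slow_step` are hypothetical names, and the two update rules are simple arithmetic placeholders standing in for attention. What the sketch does show is the asymmetry: the fast stream touches every token, while the slow state is updated once per chunk.

```python
def two_stream_scan(tokens, chunk_size, fast_step, slow_step, slow_state):
    """Process tokens chunk by chunk, updating the slow state once per chunk."""
    fast_outputs = []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        # Fast stream: works on the current chunk, conditioned on the slow state.
        fast_out = fast_step(chunk, slow_state)
        # Slow stream: a single update per chunk, reading from the fast stream.
        slow_state = slow_step(slow_state, fast_out)
        fast_outputs.append(fast_out)
    return fast_outputs, slow_state


# Placeholder update rules (assumptions for illustration, not attention).
def fast_step(chunk, slow_state):
    return [token + slow_state for token in chunk]


def slow_step(slow_state, fast_out):
    return 0.5 * slow_state + 0.5 * (sum(fast_out) / len(fast_out))


outputs, final_state = two_stream_scan(
    tokens=[float(i) for i in range(16)],
    chunk_size=4,
    fast_step=fast_step,
    slow_step=slow_step,
    slow_state=0.0,
)
print(len(outputs), final_state)
```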
PatchEmbedding layer
This custom keras.layers.Layer is useful for generating patches from the image and
transforming them into a higher-dimensional embedding space using keras.layers.Embedding.
The patching operation is done using a keras.layers.Conv2D instance instead of a
traditional tf.image.extract_patches to allow for vectorization.
Once the patching of images is complete, we reshape the image patches in order to get a flattened representation where the number of dimensions is the embedding dimension. At this stage, we also inject positional information to the tokens.
After we obtain the tokens we chunk them. The chunking operation involves taking fixed-size sequences from the embedding output to create 'chunks', which will then be used as the final input to the model.
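With the configuration used in this example, the patching and chunking arithmetic works out as follows. This is a sketch that hard-codes the values from `config`; the list-based chunking at the end only mirrors what the reshape does to the token axis.

```python
image_size, patch_size = 48, 4  # config["image_size"], config["patch_size"]
embed_dim, chunk_size = 64, 8   # config["embed_dim"], config["chunk_size"]

patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2
num_chunks = num_patches // chunk_size

# Shape of one example after patching, then after chunking.
token_shape = (num_patches, embed_dim)
chunked_shape = (num_chunks, chunk_size, embed_dim)
print(token_shape, chunked_shape)

# Chunking a flat list of token ids mirrors the reshape.
token_ids = list(range(num_patches))
chunks = [token_ids[i:i + chunk_size] for i in range(0, num_patches, chunk_size)]
print(chunks[0], chunks[1])
```

Each chunk then becomes one time step for the recurrent cell defined later.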
class PatchEmbedding(layers.Layer):
    """Image to Patch Embedding.

    Args:
        image_size (`Tuple[int]`): Size of the input image.
        patch_size (`Tuple[int]`): Size of the patch.
        embed_dim (`int`): Dimension of the embedding.
        chunk_size (`int`): Number of patches to be chunked.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        chunk_size,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Compute the patch resolution.
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]

        # Store the parameters.
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]

        # Define the positions of the patches.
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

        # Create the layers.
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            name="projection",
        )
        self.flatten = layers.Reshape(
            target_shape=(-1, embed_dim),
            name="flatten",
        )
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches,
            output_dim=embed_dim,
            name="position_embedding",
        )
        self.layernorm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )
        self.chunking_layer = layers.Reshape(
            target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
            name="chunking_layer",
        )

    def call(self, inputs):
        # Project the inputs to the embedding dimension.
        x = self.projection(inputs)

        # Flatten the patches and add position embedding.
        x = self.flatten(x)
        x = x + self.position_embedding(self.positions)

        # Normalize the embeddings.
        x = self.layernorm(x)

        # Chunk the tokens.
        x = self.chunking_layer(x)

        return x
FeedForwardNetwork Layer
This custom keras.layers.Layer instance allows us to define a generic FFN along with a
dropout.
class FeedForwardNetwork(layers.Layer):
    """Feed Forward Network.

    Args:
        dims (`int`): Number of units in FFN.
        dropout (`float`): Dropout probability for FFN.
    """

    def __init__(self, dims, dropout, **kwargs):
        super().__init__(**kwargs)

        # Create the layers.
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation=tf.nn.gelu),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ],
            name="ffn",
        )
        self.layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )

    def call(self, inputs):
        # Apply the FFN.
        x = self.layernorm(inputs)
        x = inputs + self.ffn(x)
        return x
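The `call` method above follows a pre-norm residual pattern: normalize, apply the sublayer, then add the original input back. The sketch below reproduces that pattern on a plain list of floats. It leaves out the learned scale and shift of `LayerNormalization`, and the sublayer is an arbitrary stand-in rather than the two dense layers.

```python
import math


def layer_norm(vector, epsilon=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(vector) / len(vector)
    variance = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(variance + epsilon) for v in vector]


def pre_norm_residual(vector, sublayer):
    """Compute x + sublayer(layer_norm(x))."""
    normalized = layer_norm(vector)
    return [x + y for x, y in zip(vector, sublayer(normalized))]


# A stand-in sublayer (assumption): halve every feature.
normed = layer_norm([1.0, 2.0, 3.0, 4.0])
out = pre_norm_residual([1.0, 2.0, 3.0, 4.0], lambda v: [0.5 * x for x in v])
print(normed)
print(out)
```

Because the input is added back unchanged, the sublayer only has to learn a correction to it.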
BaseAttention layer
This custom keras.layers.Layer instance is a super/base class that wraps a
keras.layers.MultiHeadAttention layer along with some other components. This gives us
basic common denominator functionality for all the Attention layers/modules in our model.
class BaseAttention(layers.Layer):
    """Base Attention Module.

    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        dropout (`float`): Dropout probability for attention module.
    """

    def __init__(self, num_heads, key_dim, dropout, **kwargs):
        super().__init__(**kwargs)
        self.multi_head_attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
            name="mha",
        )
        self.query_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="q_layernorm",
        )
        self.key_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="k_layernorm",
        )
        self.value_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="v_layernorm",
        )

        self.attention_scores = None

    def call(self, input_query, key, value):
        # Apply the attention module.
        query = self.query_layernorm(input_query)
        key = self.key_layernorm(key)
        value = self.value_layernorm(value)
        (attention_outputs, attention_scores) = self.multi_head_attention(
            query=query,
            key=key,
            value=value,
            return_attention_scores=True,
        )

        # Save the attention scores for later visualization.
        self.attention_scores = attention_scores

        # Add the input to the attention output.
        x = input_query + attention_outputs
        return x
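To show what the attention scores saved above represent, here is a minimal single-head scaled dot-product attention in pure Python. It is a sketch under simplifying assumptions: there are no learned projections, no dropout, no layer normalization and only one head, so it is not equivalent to `MultiHeadAttention`. The helper names `softmax` and `attention` are introduced for this example.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Single-head scaled dot-product attention on lists of vectors."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    scores = [
        softmax([scale * sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in keys])
        for q in queries
    ]
    outputs = [
        [sum(w * v[d] for w, v in zip(row, values)) for d in range(len(values[0]))]
        for row in scores
    ]
    return outputs, scores


queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
outputs, scores = attention(queries, keys, values)

# Residual connection: add the query input back to the attention output.
residual = [
    [q + o for q, o in zip(q_row, o_row)]
    for q_row, o_row in zip(queries, outputs)
]
print(scores)
```

Each row of `scores` belongs to one query and sums to 1 over the keys, which is what makes the scores usable as a heat map later on.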
Attention with FeedForwardNetwork layer
This custom keras.layers.Layer implementation combines the BaseAttention and
FeedForwardNetwork components to develop one block which will be used repeatedly within
the model. This module is highly customizable and flexible, allowing for changes within
the internal layers.
class AttentionWithFFN(layers.Layer):
    """Attention with Feed Forward Network.

    Args:
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Create the layers.
        self.attention = BaseAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=attn_dropout,
            name="base_attn",
        )
        self.ffn = FeedForwardNetwork(
            dims=ffn_dims,
            dropout=ffn_dropout,
            name="ffn",
        )

        self.attention_scores = None

    def call(self, query, key, value):
        # Apply the attention module.
        x = self.attention(query, key, value)

        # Save the attention scores for later visualization.
        self.attention_scores = self.attention.attention_scores

        # Apply the FFN.
        x = self.ffn(x)
        return x
Algorithm 1 (the pseudocode) depicts recurrence with the help of for loops. Looping
does make the implementation simpler, but it harms training time. In this section we wrap
the custom recurrence logic inside of the CustomRecurrentCell. This custom cell will
then be wrapped with the Keras RNN API, which makes the entire code vectorizable.
This custom cell, implemented as a keras.layers.Layer, is the integral part of the
logic for the model.
The cell's functionality can be divided into 2 parts:
- Slow Stream (Temporal Latent Bottleneck): This module consists of a single
AttentionWithFFN layer that parses the output of the
previous Slow Stream, an intermediate hidden representation (which is the latent in
Temporal Latent Bottleneck) as the Query, and the output of the latest Fast Stream as Key
and Value. This layer can also be construed as a CrossAttention layer.
- Fast Stream (Perceptual Module): This module consists of intertwined AttentionWithFFN
layers. This stream consists of
n layers of SelfAttention and CrossAttention
in a sequential manner.
class CustomRecurrentCell(layers.Layer):
    """Custom Recurrent Cell.

    Args:
        chunk_size (`int`): Number of tokens in a chunk.
        r (`int`): One Cross Attention per **r** Self Attention.
        num_layers (`int`): Number of layers.
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        chunk_size,
        r,
        num_layers,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Save the arguments.
        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_dropout = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attn_dropout = attn_dropout

        # Create the state_size and output_size. This is important for
        # custom recurrence logic.
        self.state_size = tf.TensorShape([chunk_size, ffn_dims])
        self.output_size = tf.TensorShape([chunk_size, ffn_dims])

        self.get_attention_scores = False
        self.attention_scores = []

        # Perceptual Module
        perceptual_module = list()
        for layer_idx in range(num_layers):
            perceptual_module.append(
                AttentionWithFFN(
                    ffn_dims=ffn_dims,
                    ffn_dropout=ffn_dropout,
                    num_heads=num_heads,
                    key_dim=key_dim,
                    attn_dropout=attn_dropout,
                    name=f"pm_self_attn_{layer_idx}",
                )
            )
            if layer_idx % r == 0:
                perceptual_module.append(
                    AttentionWithFFN(
                        ffn_dims=ffn_dims,
                        ffn_dropout=ffn_dropout,
                        num_heads=num_heads,
                        key_dim=key_dim,
                        attn_dropout=attn_dropout,
                        name=f"pm_cross_attn_ffn_{layer_idx}",
                    )
                )
        self.perceptual_module = perceptual_module

        # Temporal Latent Bottleneck Module
        self.tlb_module = AttentionWithFFN(
            ffn_dims=ffn_dims,
            ffn_dropout=ffn_dropout,
            num_heads=num_heads,
            key_dim=key_dim,
            attn_dropout=attn_dropout,
            name="tlb_cross_attn_ffn",
        )

    def call(self, inputs, states):
        # inputs => (batch, chunk_size, dims)
        # states => [(batch, chunk_size, units)]
        slow_stream = states[0]
        fast_stream = inputs

        for layer_idx, layer in enumerate(self.perceptual_module):
            fast_stream = layer(query=fast_stream, key=fast_stream, value=fast_stream)

            if layer_idx % self.r == 0:
                fast_stream = layer(
                    query=fast_stream, key=slow_stream, value=slow_stream
                )

        slow_stream = self.tlb_module(
            query=slow_stream, key=fast_stream, value=fast_stream
        )

        # Save the attention scores for later visualization.
        if self.get_attention_scores:
            self.attention_scores.append(self.tlb_module.attention_scores)

        return fast_stream, [slow_stream]
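The layer layout that the constructor above builds for the Perceptual Module can be listed without instantiating any Keras layers. The helper below is a hypothetical sketch that mirrors only the naming loop in `__init__`, using the `num_layers` and `r` values from `config`.

```python
def perceptual_module_layout(num_layers, r):
    """List the layer names created by the Perceptual Module loop in `__init__`."""
    names = []
    for layer_idx in range(num_layers):
        names.append(f"pm_self_attn_{layer_idx}")
        if layer_idx % r == 0:
            names.append(f"pm_cross_attn_ffn_{layer_idx}")
    return names


# config["num_layers"] = 4 and config["r"] = 2
layout = perceptual_module_layout(num_layers=4, r=2)
print(layout)
```

With these settings the list holds six blocks: four named for self-attention and two named for cross-attention, the latter created at layer indices 0 and 2.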
TemporalLatentBottleneckModel to encapsulate full model
Here, we just wrap the full model so as to expose it for training.
class TemporalLatentBottleneckModel(keras.Model):
    """Model Trainer.

    Args:
        patch_layer (`keras.layers.Layer`): Patching layer.
        custom_cell (`keras.layers.Layer`): Custom Recurrent Cell.
    """

    def __init__(self, patch_layer, custom_cell, **kwargs):
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.rnn = layers.RNN(custom_cell, name="rnn")
        self.gap = layers.GlobalAveragePooling1D(name="gap")
        self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")

    def call(self, inputs):
        x = self.patch_layer(inputs)
        x = self.rnn(x)
        x = self.gap(x)
        outputs = self.head(x)
        return outputs
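It can help to trace tensor shapes through `call`. The function below is a sketch that only does the arithmetic; it assumes the default `return_sequences=False` of `layers.RNN` (so only the output of the last chunk is kept) and the configuration values used in this example. `trace_shapes` is a hypothetical helper, not part of the model.

```python
def trace_shapes(
    batch, image_size=48, patch_size=4, embed_dim=64, chunk_size=8, num_classes=10
):
    """Return the expected tensor shape after each stage of the model's `call`."""
    num_patches = (image_size // patch_size) ** 2
    num_chunks = num_patches // chunk_size
    return {
        "input": (batch, image_size, image_size, 3),
        # PatchEmbedding: one group of `chunk_size` tokens per RNN time step.
        "patch_layer": (batch, num_chunks, chunk_size, embed_dim),
        # layers.RNN with return_sequences=False keeps the last step only.
        "rnn": (batch, chunk_size, embed_dim),
        # GlobalAveragePooling1D averages over the token axis.
        "gap": (batch, embed_dim),
        "head": (batch, num_classes),
    }


shapes = trace_shapes(batch=2048)
for stage, shape in shapes.items():
    print(stage, shape)
```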
To begin training, we now define the components individually and pass them as arguments
to our wrapper class, which will prepare the final model for training. We define a
PatchEmbedding layer and the CustomRecurrentCell-based RNN.
# Build the model.
patch_layer = PatchEmbedding(
    image_size=(config["image_size"], config["image_size"]),
    patch_size=(config["patch_size"], config["patch_size"]),
    embed_dim=config["embed_dim"],
    chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
    chunk_size=config["chunk_size"],
    r=config["r"],
    num_layers=config["num_layers"],
    ffn_dims=config["embed_dim"],
    ffn_dropout=config["ffn_drop"],
    num_heads=config["num_heads"],
    key_dim=config["embed_dim"],
    attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(
    patch_layer=patch_layer,
    custom_cell=custom_rnn_cell,
)
We use the AdamW optimizer since it has been shown to perform very well on several benchmark
tasks from an optimization perspective. It is a version of the keras.optimizers.Adam
optimizer with decoupled weight decay.
For a loss function, we make use of the keras.losses.SparseCategoricalCrossentropy
function, which computes the cross-entropy between the predicted class probabilities and
the integer class labels. We also calculate accuracy on our data as a sanity-check.
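For a single example, sparse categorical cross-entropy reduces to the negative log of the probability assigned to the true class. The sketch below computes it in pure Python for illustration; the function name matches the loss string but is defined here, not imported from Keras.

```python
import math


def sparse_categorical_crossentropy(label, probabilities):
    """Cross-entropy for one example: -log of the probability of the true class."""
    return -math.log(probabilities[label])


# A fairly confident, correct prediction over three classes.
confident = sparse_categorical_crossentropy(1, [0.1, 0.7, 0.2])
# A uniform guess over ten classes, as an untrained CIFAR-10 classifier might make.
uniform = sparse_categorical_crossentropy(3, [0.1] * 10)
print(confident, uniform)
```

The uniform case evaluates to ln(10), roughly 2.30, which is close to the validation loss reported after the first epoch in the training log.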
optimizer = AdamW(
    learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
We pass the training and validation datasets to model.fit() and run training.
history = model.fit(
    train_ds,
    epochs=config["epochs"],
    validation_data=val_ds,
)
Epoch 1/30
20/20 [==============================] - 104s 3s/step - loss: 2.6284 - accuracy: 0.1010 - val_loss: 2.2835 - val_accuracy: 0.1251
Epoch 2/30
20/20 [==============================] - 35s 2s/step - loss: 2.2797 - accuracy: 0.1542 - val_loss: 2.1721 - val_accuracy: 0.1846
Epoch 3/30
20/20 [==============================] - 34s 2s/step - loss: 2.1989 - accuracy: 0.1883 - val_loss: 2.1288 - val_accuracy: 0.2241
Epoch 4/30
20/20 [==============================] - 34s 2s/step - loss: 2.1267 - accuracy: 0.2192 - val_loss: 2.0919 - val_accuracy: 0.2477
Epoch 5/30
20/20 [==============================] - 33s 2s/step - loss: 2.0653 - accuracy: 0.2393 - val_loss: 2.0134 - val_accuracy: 0.2671
Epoch 6/30
20/20 [==============================] - 34s 2s/step - loss: 2.0327 - accuracy: 0.2524 - val_loss: 2.0258 - val_accuracy: 0.2665
Epoch 7/30
20/20 [==============================] - 34s 2s/step - loss: 2.0047 - accuracy: 0.2598 - val_loss: 1.9871 - val_accuracy: 0.2831
Epoch 8/30
20/20 [==============================] - 34s 2s/step - loss: 1.9765 - accuracy: 0.2781 - val_loss: 1.9550 - val_accuracy: 0.2968
Epoch 9/30
20/20 [==============================] - 34s 2s/step - loss: 1.9432 - accuracy: 0.2883 - val_loss: 1.9559 - val_accuracy: 0.2969
Epoch 10/30
20/20 [==============================] - 33s 2s/step - loss: 1.9062 - accuracy: 0.3020 - val_loss: 1.8967 - val_accuracy: 0.3200
Epoch 11/30
20/20 [==============================] - 33s 2s/step - loss: 1.8741 - accuracy: 0.3158 - val_loss: 1.8648 - val_accuracy: 0.3330
Epoch 12/30
20/20 [==============================] - 33s 2s/step - loss: 1.8336 - accuracy: 0.3282 - val_loss: 1.7863 - val_accuracy: 0.3464
Epoch 13/30
20/20 [==============================] - 33s 2s/step - loss: 1.7931 - accuracy: 0.3434 - val_loss: 1.7364 - val_accuracy: 0.3669
Epoch 14/30
20/20 [==============================] - 34s 2s/step - loss: 1.7491 - accuracy: 0.3558 - val_loss: 1.7104 - val_accuracy: 0.3710
Epoch 15/30
20/20 [==============================] - 34s 2s/step - loss: 1.7182 - accuracy: 0.3686 - val_loss: 1.6883 - val_accuracy: 0.3866
Epoch 16/30
20/20 [==============================] - 33s 2s/step - loss: 1.6819 - accuracy: 0.3790 - val_loss: 1.6493 - val_accuracy: 0.3933
Epoch 17/30
20/20 [==============================] - 33s 2s/step - loss: 1.6594 - accuracy: 0.3873 - val_loss: 1.6021 - val_accuracy: 0.4161
Epoch 18/30
20/20 [==============================] - 33s 2s/step - loss: 1.6279 - accuracy: 0.3946 - val_loss: 1.5949 - val_accuracy: 0.4170
Epoch 19/30
20/20 [==============================] - 34s 2s/step - loss: 1.6127 - accuracy: 0.4015 - val_loss: 1.5672 - val_accuracy: 0.4239
Epoch 20/30
20/20 [==============================] - 33s 2s/step - loss: 1.5995 - accuracy: 0.4079 - val_loss: 1.5795 - val_accuracy: 0.4223
Epoch 21/30
20/20 [==============================] - 34s 2s/step - loss: 1.5809 - accuracy: 0.4167 - val_loss: 1.5294 - val_accuracy: 0.4390
Epoch 22/30
20/20 [==============================] - 34s 2s/step - loss: 1.5572 - accuracy: 0.4254 - val_loss: 1.5192 - val_accuracy: 0.4455
Epoch 23/30
20/20 [==============================] - 33s 2s/step - loss: 1.5468 - accuracy: 0.4291 - val_loss: 1.5243 - val_accuracy: 0.4424
Epoch 24/30
20/20 [==============================] - 34s 2s/step - loss: 1.5347 - accuracy: 0.4335 - val_loss: 1.4920 - val_accuracy: 0.4532
Epoch 25/30
20/20 [==============================] - 33s 2s/step - loss: 1.5245 - accuracy: 0.4387 - val_loss: 1.4805 - val_accuracy: 0.4584
Epoch 26/30
20/20 [==============================] - 33s 2s/step - loss: 1.5057 - accuracy: 0.4469 - val_loss: 1.4754 - val_accuracy: 0.4592
Epoch 27/30
20/20 [==============================] - 34s 2s/step - loss: 1.5013 - accuracy: 0.4457 - val_loss: 1.4688 - val_accuracy: 0.4619
Epoch 28/30
20/20 [==============================] - 33s 2s/step - loss: 1.4852 - accuracy: 0.4548 - val_loss: 1.4543 - val_accuracy: 0.4704
Epoch 29/30
20/20 [==============================] - 34s 2s/step - loss: 1.4728 - accuracy: 0.4570 - val_loss: 1.4437 - val_accuracy: 0.4751
Epoch 30/30
20/20 [==============================] - 34s 2s/step - loss: 1.4652 - accuracy: 0.4606 - val_loss: 1.4546 - val_accuracy: 0.4726
model.fit() returns a history object, which stores the values of the metrics
generated during the training run (it lives only in memory and needs to be saved manually to persist).
We now display the loss and accuracy curves for the training and validation sets.
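One simple way to keep the metrics beyond the current session is to write them to a JSON file. The sketch below uses a small stand-in dictionary shaped like `history.history`, filled with the first two epochs from the log above; the file name `history.json` and the temporary directory are arbitrary choices for illustration.

```python
import json
import os
import tempfile

# Stand-in for `history.history` (assumption: metric names mapped to lists).
history_dict = {
    "loss": [2.6284, 2.2797],
    "val_loss": [2.2835, 2.1721],
}

# Hypothetical output location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "history.json")
with open(path, "w") as f:
    # Cast to float so that NumPy scalars, if present, serialize cleanly.
    json.dump({k: [float(v) for v in vals] for k, vals in history_dict.items()}, f)

with open(path) as f:
    restored = json.load(f)
print(restored["val_loss"])
```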
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()
plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()
Now that we have trained our model, it is time for some visualizations. The Fast Stream (Transformers) processes a chunk of tokens. The Slow Stream processes each chunk and attends to tokens that are useful for the task.
In this section we visualize the attention map of the Slow Stream. This is done by extracting the attention scores from the TLB layer at each chunk's intersection and storing them on the RNN cell. The scores are then 'ballooned' up to the image size for display.
def score_to_viz(chunk_score):
    # get the most attended token
    chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
    # get the mean across heads
    chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)
    return chunk_viz


# Get a batch of images and labels from the testing dataset
images, labels = next(iter(test_ds))

# Set the get_attention_scores flag to True
model.rnn.cell.get_attention_scores = True

# Run the model with the testing images and grab the
# attention scores.
outputs = model(images)
list_chunk_scores = model.rnn.cell.attention_scores

# Process the attention scores in order to visualize them
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
    chunk_viz,
    (
        config["batch_size"],
        config["image_size"] // config["patch_size"],
        config["image_size"] // config["patch_size"],
        1,
    ),
)
upsampled_heat_map = layers.UpSampling2D(
    size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)
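The two reductions in `score_to_viz` can be checked by hand on a toy example. The sketch below is a pure-Python version for a single image, assuming scores laid out as `scores[head][query][key]`; `score_to_viz_single` and the toy numbers are made up for illustration.

```python
def score_to_viz_single(scores):
    """scores[head][query][key] -> one value per key token for a single example."""
    num_heads = len(scores)
    num_keys = len(scores[0][0])
    per_head = [
        # Strongest attention any query pays to each key (max over the query axis).
        [max(query_row[k] for query_row in head) for k in range(num_keys)]
        for head in scores
    ]
    # Average the per-key maxima across heads.
    return [sum(head[k] for head in per_head) / num_heads for k in range(num_keys)]


toy_scores = [
    [[0.7, 0.3], [0.2, 0.8]],  # head 0: 2 queries x 2 keys
    [[0.5, 0.5], [0.9, 0.1]],  # head 1
]
viz = score_to_viz_single(toy_scores)
print(viz)

# Grid arithmetic for the reshape and the upsampling.
grid_side = 48 // 4            # config["image_size"] // config["patch_size"]
upsampled_side = grid_side * 4  # UpSampling2D with size=(4, 4)
print(grid_side, upsampled_side)
```

Since the keys of the TLB layer are the tokens of the current chunk, the result is one value per image patch, which is why it can be laid out on the 12 x 12 patch grid and upsampled back to 48 x 48.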
Run the following code snippet to get different images and their attention maps.
# Sample a random image. `random.randint` includes both endpoints, so the
# upper bound is the last valid index of the batch.
index = random.randint(0, config["batch_size"] - 1)
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]

# Plot the visualization
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")

image = ax[1].imshow(orig_image)
ax[1].imshow(
    overlay_image,
    cmap="inferno",
    alpha=0.6,
    extent=image.get_extent(),
)
ax[1].set_title("TLB Attention:")

plt.show()
This example has demonstrated an implementation of the Temporal Latent Bottleneck mechanism. It highlights how historical states can be compressed and stored in a Temporal Latent Bottleneck that receives regular updates from a Perceptual Module.
In the original paper, the authors conducted extensive tests across different modalities, ranging from Supervised Image Classification to applications in Reinforcement Learning.
While we have only displayed a method to apply this mechanism to Image Classification, it can be extended to other modalities too with minimal changes.
Note: While building this example we did not have the official code to refer to. This means that our implementation is inspired by the paper with no claims of being a complete reproduction. For more details on the training process one can head over to our GitHub repository.
Thanks to Aniket Didolkar (the first author) and Anirudh Goyal (the third author) for reviewing our work.
We would like to thank PyImageSearch for a Colab Pro account and JarvisLabs.ai for the GPU credits.