Author: Aritra Roy Gosthipaty, Ayush Thakur (equal contribution)
Date created: 2022/01/12
Last modified: 2024/01/15
Description: A Transformer-based architecture for video classification.
View in Colab • GitHub source
Videos are sequences of images. Suppose you have an image representation model (CNN, ViT, etc.) and a sequence model (RNN, LSTM, etc.) at hand, and you are asked to adapt them for video classification. The simplest approach would be to apply the image model to individual frames, use the sequence model to learn sequences of image features, then apply a classification head on the learned sequence representation. The Keras example Video Classification with a CNN-RNN Architecture explains this approach in detail. Alternatively, you can build a hybrid Transformer-based model for video classification, as shown in the Keras example Video Classification with Transformers.
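As a rough, hedged illustration of that per-frame-features-plus-sequence-model idea (it is not used anywhere in this example), the sketch below wires a small frame encoder into an LSTM; the layer sizes and input dimensions are arbitrary placeholders, not recommendations.

# Minimal sketch of the "image model per frame + sequence model" approach.
import keras
from keras import layers

num_frames, height, width, channels, n_classes = 16, 64, 64, 3, 10

# A tiny per-frame encoder standing in for any image representation model.
frame_encoder = keras.Sequential(
    [
        layers.Conv2D(32, 3, activation="relu"),
        layers.GlobalAveragePooling2D(),
    ]
)

inputs = keras.Input(shape=(num_frames, height, width, channels))
# Apply the same frame encoder to every frame, then model the frame sequence.
frame_features = layers.TimeDistributed(frame_encoder)(inputs)
sequence_features = layers.LSTM(64)(frame_features)
outputs = layers.Dense(n_classes, activation="softmax")(sequence_features)
hybrid_model = keras.Model(inputs, outputs)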
In this example, we minimally implement ViViT: A Video Vision Transformer by Arnab et al., a pure Transformer-based model for video classification. The authors propose a novel embedding scheme and a number of Transformer variants to model video clips. We implement the embedding scheme and one of the variants of the Transformer architecture, for simplicity.
This example requires the medmnist package, which can be installed by running the code cell below.
!pip install -qq medmnist
import os
import io
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf # for data preprocessing only
import keras
from keras import layers, ops
# Setting seed for reproducibility
SEED = 42
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)
The hyperparameters below were chosen via a hyperparameter search. You can learn more about the process in the "conclusion" section.
# DATA
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11
# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
# TRAINING
EPOCHS = 60
# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 3  # tubelet tokens for a cubic volume
# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8
For our example, we use the MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification dataset. Its 28×28×28 volumes, treated here as short videos, are lightweight and easy to train on.
def download_and_prepare_dataset(data_info: dict):
"""Utility function to download the dataset.
Arguments:
data_info (dict): Dataset metadata.
"""
data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])
with np.load(data_path) as data:
# Get videos
train_videos = data["train_images"]
valid_videos = data["val_images"]
test_videos = data["test_images"]
# Get labels
train_labels = data["train_labels"].flatten()
valid_labels = data["val_labels"].flatten()
test_labels = data["test_labels"].flatten()
return (
(train_videos, train_labels),
(valid_videos, valid_labels),
(test_videos, test_labels),
)
# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]
# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]
We now build the tf.data pipeline for training, validation, and testing.
def preprocess(frames: tf.Tensor, label: tf.Tensor):
"""Preprocess the frames tensors and parse the labels."""
# Preprocess images
frames = tf.image.convert_image_dtype(
frames[
..., tf.newaxis
], # The new axis is to help for further processing with Conv3D layers
tf.float32,
)
# Parse label
label = tf.cast(label, tf.float32)
return frames, label
def prepare_dataloader(
videos: np.ndarray,
labels: np.ndarray,
loader_type: str = "train",
batch_size: int = BATCH_SIZE,
):
"""Utility function to prepare the dataloader."""
dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
if loader_type == "train":
dataset = dataset.shuffle(BATCH_SIZE * 2)
dataloader = (
dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
.batch(batch_size)
.prefetch(tf.data.AUTOTUNE)
)
return dataloader
trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
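As an optional sanity check (not part of the original pipeline), you can pull one batch from the training loader and confirm the shapes: videos come out as (BATCH_SIZE, 28, 28, 28, 1) and labels as (BATCH_SIZE,).

# Optional sanity check: inspect one batch from the training loader.
sample_videos, sample_labels = next(iter(trainloader))
print(sample_videos.shape)  # (32, 28, 28, 28, 1)
print(sample_labels.shape)  # (32,)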
In ViTs, an image is divided into patches, which are then spatially flattened, a process known as tokenization. For a video, one can repeat this process for individual frames. Uniform frame sampling as suggested by the authors is a tokenization scheme in which we sample frames from the video clip and perform simple ViT tokenization.
Uniform Frame Sampling (figure source: the ViViT paper)
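For intuition, here is a rough sketch (not used later in this example) of what uniform frame sampling followed by per-frame ViT-style patching could look like; the frame count and patch size are illustrative assumptions.

# Sketch of uniform frame sampling + per-frame patching (illustrative only).
def uniform_frame_sampling(video, num_sampled_frames, patch_size):
    """video: array of shape (num_frames, height, width, channels)."""
    num_frames = video.shape[0]
    # Sample frames at evenly spaced temporal indices.
    indices = np.linspace(0, num_frames - 1, num_sampled_frames).astype(int)
    sampled_frames = video[indices]
    # Split each sampled frame into non-overlapping patches and flatten them.
    tokens = []
    for frame in sampled_frames:
        for i in range(0, frame.shape[0], patch_size):
            for j in range(0, frame.shape[1], patch_size):
                tokens.append(frame[i : i + patch_size, j : j + patch_size].ravel())
    return np.stack(tokens)  # (num_sampled_frames * patches_per_frame, patch_dim)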
Tubelet embedding differs in that it also captures temporal information from the video. First, we extract volumes from the video; these volumes contain patches of the frames as well as the temporal information. The volumes are then flattened to build video tokens.
Tubelet Embedding (figure source: the ViViT paper)
class TubeletEmbedding(layers.Layer):
def __init__(self, embed_dim, patch_size, **kwargs):
super().__init__(**kwargs)
self.projection = layers.Conv3D(
filters=embed_dim,
kernel_size=patch_size,
strides=patch_size,
padding="VALID",
)
self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
def call(self, videos):
projected_patches = self.projection(videos)
flattened_patches = self.flatten(projected_patches)
return flattened_patches
This layer adds positional information to the encoded video tokens.
class PositionalEncoder(layers.Layer):
def __init__(self, embed_dim, **kwargs):
super().__init__(**kwargs)
self.embed_dim = embed_dim
def build(self, input_shape):
_, num_tokens, _ = input_shape
self.position_embedding = layers.Embedding(
input_dim=num_tokens, output_dim=self.embed_dim
)
self.positions = ops.arange(0, num_tokens, 1)
def call(self, encoded_tokens):
# Encode the positions and add it to the encoded tokens
encoded_positions = self.position_embedding(self.positions)
encoded_tokens = encoded_tokens + encoded_positions
return encoded_tokens
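As a quick optional check (not part of the original example), you can run a dummy batch through the two embedding layers to verify the token shape: with 28×28×28 inputs and 8×8×8 tubelets, the strided Conv3D produces 3×3×3 = 27 tokens of dimension PROJECTION_DIM.

# Optional shape check: pass a dummy batch through the embedding layers.
dummy_videos = np.zeros((2, *INPUT_SHAPE), dtype="float32")
demo_embedder = TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE)
demo_encoder = PositionalEncoder(embed_dim=PROJECTION_DIM)
demo_tokens = demo_encoder(demo_embedder(dummy_videos))
print(demo_tokens.shape)  # (2, 27, 128)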
The authors suggest 4 variants of the Vision Transformer for video:
- Spatio-temporal attention
- Factorized encoder
- Factorized self-attention
- Factorized dot-product attention
In this example, we will implement the Spatio-temporal attention model for simplicity. The following code snippet is heavily inspired by Image classification with Vision Transformer. One can also refer to the official repository of ViViT, which contains all the variants implemented in JAX.
def create_vivit_classifier(
tubelet_embedder,
positional_encoder,
input_shape=INPUT_SHAPE,
transformer_layers=NUM_LAYERS,
num_heads=NUM_HEADS,
embed_dim=PROJECTION_DIM,
layer_norm_eps=LAYER_NORM_EPS,
num_classes=NUM_CLASSES,
):
# Get the input layer
inputs = layers.Input(shape=input_shape)
# Create patches.
patches = tubelet_embedder(inputs)
# Encode patches.
encoded_patches = positional_encoder(patches)
# Create multiple layers of the Transformer block.
for _ in range(transformer_layers):
# Layer normalization and MHSA
x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
)(x1, x1)
# Skip connection
x2 = layers.Add()([attention_output, encoded_patches])
# Layer Normalization and MLP
x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
x3 = keras.Sequential(
[
layers.Dense(units=embed_dim * 4, activation=ops.gelu),
layers.Dense(units=embed_dim, activation=ops.gelu),
]
)(x3)
# Skip connection
encoded_patches = layers.Add()([x3, x2])
# Layer normalization and Global average pooling.
representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
representation = layers.GlobalAvgPool1D()(representation)
# Classify outputs.
outputs = layers.Dense(units=num_classes, activation="softmax")(representation)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=outputs)
return model
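If you want to inspect the architecture before training (optional, and not part of the original flow), you can instantiate the classifier and print its summary:

# Optional: build the model once to inspect its layers and parameter count.
vivit = create_vivit_classifier(
    tubelet_embedder=TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE),
    positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
)
vivit.summary()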
def run_experiment():
# Initialize model
model = create_vivit_classifier(
tubelet_embedder=TubeletEmbedding(
embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
),
positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
)
# Compile the model with the optimizer, loss function
# and the metrics.
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Train the model.
_ = model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)
_, accuracy, top_5_accuracy = model.evaluate(testloader)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
return model
model = run_experiment()
Test accuracy: 76.72%
Test top 5 accuracy: 97.54%
NUM_SAMPLES_VIZ = 25
testsamples, labels = next(iter(testloader))
testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]
ground_truths = []
preds = []
videos = []
for i, (testsample, label) in enumerate(zip(testsamples, labels)):
# Generate gif
testsample = np.reshape(testsample.numpy(), (-1, 28, 28))
with io.BytesIO() as gif:
imageio.mimsave(gif, (testsample * 255).astype("uint8"), "GIF", fps=5)
videos.append(gif.getvalue())
# Get model prediction
output = model.predict(ops.expand_dims(testsample, axis=0))[0]
pred = np.argmax(output, axis=0)
ground_truths.append(label.numpy().astype("int"))
preds.append(pred)
def make_box_for_grid(image_widget, fit):
"""Make a VBox to hold caption/image for demonstrating option_fit values.
Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
"""
# Make the caption
if fit is not None:
fit_str = "'{}'".format(fit)
else:
fit_str = str(fit)
h = ipywidgets.HTML(value=fit_str)
# Make the green box with the image widget inside it
boxb = ipywidgets.widgets.Box()
boxb.children = [image_widget]
# Compose into a vertical box
vb = ipywidgets.widgets.VBox()
vb.layout.align_items = "center"
vb.children = [h, boxb]
return vb
boxes = []
for i in range(NUM_SAMPLES_VIZ):
ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
true_class = info["label"][str(ground_truths[i])]
pred_class = info["label"][str(preds[i])]
caption = f"T: {true_class} | P: {pred_class}"
boxes.append(make_box_for_grid(ib, caption))
ipywidgets.widgets.GridBox(
boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
)
With a vanilla implementation, we achieve roughly 77% Top-1 accuracy on the test dataset (76.72% in the run above).
The hyperparameters used in this tutorial were finalized by running a hyperparameter search using W&B Sweeps. You can find our sweep results here and our quick analysis of the results here.
For further improvement, you could look into the following:
- Using data augmentation for videos (a minimal sketch is shown after this list).
- Using a better regularization scheme for training.
- Applying different variants of the Transformer model as in the paper.
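As a hedged illustration of the first suggestion, the sketch below applies a clip-level random horizontal flip inside the tf.data pipeline, so every frame of a clip is flipped consistently; the augmentation choice and its probability are illustrative assumptions, not settings from this tutorial or the paper.

# Sketch: clip-level augmentation in the tf.data pipeline (illustrative only).
def augment(frames: tf.Tensor, label: tf.Tensor):
    # Flip the whole clip left-right with probability 0.5 so that all frames
    # of a clip receive the same spatial transformation.
    flip = tf.random.uniform(()) > 0.5
    frames = tf.cond(flip, lambda: tf.reverse(frames, axis=[2]), lambda: frames)
    return frames, label

augmented_trainloader = (
    tf.data.Dataset.from_tensor_slices((train_videos, train_labels))
    .shuffle(BATCH_SIZE * 2)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)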
We would like to thank Anurag Arnab (first author of ViViT) for helpful discussions. We are grateful to the Weights and Biases program for providing GPU credits.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.