Author: Sayak Paul
Date created: 2021/10/20
Last modified: 2024/02/11
Description: MobileViT for image classification with combined benefits of convolutions and Transformers.
In this example, we implement the MobileViT architecture (Mehta et al.), which combines the benefits of Transformers (Vaswani et al.) and convolutions. With Transformers, we can capture long-range dependencies that result in global representations. With convolutions, we can capture spatial relationships that model locality.
Besides combining the properties of Transformers and convolutions, the authors introduce MobileViT as a general-purpose mobile-friendly backbone for different image recognition tasks. Their findings suggest that, performance-wise, MobileViT is better than other models with the same or higher complexity (MobileNetV3, for example), while being efficient on mobile devices.
Note: This example should be run with TensorFlow 2.13 or higher.
import os
import tensorflow as tf
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import layers
from keras import backend
import tensorflow_datasets as tfds
# Values are from table 4.
patch_size = 4 # 2x2, for the Transformer blocks.
image_size = 256
expansion_factor = 2 # expansion factor for the MobileNetV2 blocks.
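The MobileViT architecture is composed of the following blocks: strided convolution blocks, MobileNetV2-style inverted residual blocks, and MobileViT blocks that combine convolutions with Transformers.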
def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters, kernel_size, strides=strides, activation=keras.activations.swish, padding="same"
    )
    return conv_layer(x)
# Reference:
def correct_pad(inputs, kernel_size):
    # Index of the first spatial dimension in the input shape.
    img_dim = 2 if backend.image_data_format() == "channels_first" else 1
    input_size = inputs.shape[img_dim : (img_dim + 2)]
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )
# Reference:
def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    # Point-wise expansion.
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    # Depth-wise convolution, with explicit padding when downsampling.
    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    # Point-wise projection.
    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    # Residual connection only when input and output shapes match.
    if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m
# Reference:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(
            x3,
            hidden_units=[x.shape[-1] * 2, x.shape[-1]],
            dropout_rate=0.1,
        )
        # Skip connection 2.
        x = layers.Add()([x3, x2])
    return x
def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features
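To make the padding arithmetic in correct_pad more concrete, here is a minimal plain-Python sketch for a single spatial dimension. The helper names pad_amounts and valid_conv_output are hypothetical and exist only for this illustration. The sketch assumes a known (non-None) input size and a "valid" convolution applied after the zero padding, as in the strided branch of inverted_residual_block.

```python
def pad_amounts(input_size, kernel_size):
    """Mirror of the correct_pad arithmetic for one known spatial dimension."""
    adjust = 1 - input_size % 2  # 1 for even sizes, 0 for odd sizes
    correct = kernel_size // 2
    return (correct - adjust, correct)  # (padding before, padding after)


def valid_conv_output(input_size, kernel_size, stride, pad):
    """Output size of a 'valid' convolution applied after zero padding."""
    padded = input_size + pad[0] + pad[1]
    return (padded - kernel_size) // stride + 1


for size in (128, 127):
    pad = pad_amounts(size, kernel_size=3)
    print(size, pad, valid_conv_output(size, 3, 2, pad))
# prints:
# 128 (0, 1) 64
# 127 (1, 1) 64
```

In both cases the strided 3x3 convolution halves the resolution (rounding up), which is what a stride of 2 is meant to do here.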
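More on the MobileViT block:

- First, the feature representations (A) go through convolution blocks that capture local relationships. The expected shape of a single entry here is (h, w, num_channels).
- They are then unfolded into another vector with shape (p, n, num_channels), where p is the area of a small patch and n is (h * w) / p. So, we end up with n non-overlapping patches.
- This unfolded vector is passed through a Transformer block that captures global relationships between the patches.
- The output vector (B) is folded back into a vector of shape (h, w, num_channels), resembling a feature map coming out of convolutions.

Vectors A and B are then passed through two more convolutional layers to fuse the local and global representations. Notice how the spatial resolution of the final vector remains unchanged at this point. The authors also present an explanation of how the MobileViT block resembles a convolution block of a CNN. For more details, please refer to the original paper.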
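The unfold and fold steps described above can be sketched in plain Python on a tiny feature map. This is only an illustration of a row-major reshape from (h, w, num_channels) to (p, n, num_channels) and back, which is how the Reshape layers in mobilevit_block regroup the values. The functions unfold and fold are hypothetical helpers, not part of Keras.

```python
def unfold(feature_map, patch_area):
    """Row-major reshape from (h, w, c) to (p, n, c) using nested lists."""
    flat = [pixel for row in feature_map for pixel in row]  # h * w pixels
    if len(flat) % patch_area != 0:
        raise ValueError("h * w must be divisible by the patch area")
    n = len(flat) // patch_area
    return [flat[i * n : (i + 1) * n] for i in range(patch_area)]


def fold(unfolded, h, w):
    """Inverse reshape from (p, n, c) back to (h, w, c)."""
    flat = [pixel for group in unfolded for pixel in group]
    return [flat[r * w : (r + 1) * w] for r in range(h)]


# A 4x4 feature map with a single channel per pixel.
feature_map = [[[r * 4 + c] for c in range(4)] for r in range(4)]
unfolded = unfold(feature_map, patch_area=4)
print(len(unfolded), len(unfolded[0]))  # 4 4, i.e. p = 4 and n = 16 / 4 = 4
print(fold(unfolded, 4, 4) == feature_map)  # True
```

Because folding is the exact inverse of unfolding, the spatial layout of the feature map is recovered after the Transformer layers, which is why the block can hand its output straight back to convolutions.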
Next, we combine these blocks together and implement the MobileViT architecture (XXS variant). The following figure (taken from the original paper) presents a schematic representation of the architecture:
def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


mobilevit_xxs = create_mobilevit()
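As a quick sanity check on the architecture above, the following sketch traces the spatial resolution through the five downsampling steps in create_mobilevit (the conv stem and the four MV2 blocks with strides=2). It assumes each stride-2 layer maps a size s to ceil(s / 2), and that the remaining blocks keep the resolution unchanged. The function trace_resolution is a hypothetical helper for this illustration.

```python
import math


def trace_resolution(image_size, strides):
    """Return the spatial size after each downsampling layer."""
    sizes = [image_size]
    for stride in strides:
        sizes.append(math.ceil(sizes[-1] / stride))
    return sizes


# Conv stem followed by the four strided MV2 blocks, in order.
downsampling_strides = [2, 2, 2, 2, 2]
sizes = trace_resolution(256, downsampling_strides)
print(sizes)  # [256, 128, 64, 32, 16, 8]

# The three MobileViT blocks operate on the last three resolutions.
patch_area = 4
for size in sizes[-3:]:
    print(size, (size * size) // patch_area)
# prints:
# 32 256
# 16 64
# 8 16
```

The second column is the number of patches each MobileViT block passes to its Transformer layers, given patch_size = 4 and a 256x256 input.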
## Dataset preparation
We will be using the tf_flowers dataset to demonstrate the model. Unlike other Transformer-based architectures, MobileViT uses a simple augmentation pipeline, primarily because it has the properties of a CNN.
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5
def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp
def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)


train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)
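def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    # When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath, monitor="val_accuracy", save_best_only=True, save_weights_only=True
    )

    mobilevit_xxs.fit(
        train_dataset, validation_data=val_dataset, epochs=epochs, callbacks=[checkpoint_callback]
    )
    # Restore the best weights before the final evaluation.
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs


mobilevit_xxs = run_experiment()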
Epoch 1/30
52/52 [==============================] - 47s 459ms/step - loss: 1.3397 - accuracy: 0.4832 - val_loss: 1.7250 - val_accuracy: 0.1662
Epoch 2/30
52/52 [==============================] - 21s 404ms/step - loss: 1.1167 - accuracy: 0.6210 - val_loss: 1.9844 - val_accuracy: 0.1907
Epoch 3/30
52/52 [==============================] - 21s 403ms/step - loss: 1.0217 - accuracy: 0.6709 - val_loss: 1.8187 - val_accuracy: 0.1907
Epoch 4/30
52/52 [==============================] - 21s 409ms/step - loss: 0.9682 - accuracy: 0.7048 - val_loss: 2.0329 - val_accuracy: 0.1907
Epoch 5/30
52/52 [==============================] - 21s 408ms/step - loss: 0.9552 - accuracy: 0.7196 - val_loss: 2.1150 - val_accuracy: 0.1907
Epoch 6/30
52/52 [==============================] - 21s 407ms/step - loss: 0.9186 - accuracy: 0.7318 - val_loss: 2.9713 - val_accuracy: 0.1907
Epoch 7/30
52/52 [==============================] - 21s 407ms/step - loss: 0.8986 - accuracy: 0.7457 - val_loss: 3.2062 - val_accuracy: 0.1907
Epoch 8/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8831 - accuracy: 0.7542 - val_loss: 3.8631 - val_accuracy: 0.1907
Epoch 9/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8433 - accuracy: 0.7714 - val_loss: 1.8029 - val_accuracy: 0.3542
Epoch 10/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8489 - accuracy: 0.7763 - val_loss: 1.7920 - val_accuracy: 0.4796
Epoch 11/30
52/52 [==============================] - 21s 409ms/step - loss: 0.8256 - accuracy: 0.7884 - val_loss: 1.4992 - val_accuracy: 0.5477
Epoch 12/30
52/52 [==============================] - 21s 407ms/step - loss: 0.7859 - accuracy: 0.8123 - val_loss: 0.9236 - val_accuracy: 0.7330
Epoch 13/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7702 - accuracy: 0.8159 - val_loss: 0.8059 - val_accuracy: 0.8011
Epoch 14/30
52/52 [==============================] - 21s 403ms/step - loss: 0.7670 - accuracy: 0.8153 - val_loss: 1.1535 - val_accuracy: 0.7084
Epoch 15/30
52/52 [==============================] - 21s 408ms/step - loss: 0.7332 - accuracy: 0.8344 - val_loss: 0.7746 - val_accuracy: 0.8147
Epoch 16/30
52/52 [==============================] - 21s 404ms/step - loss: 0.7284 - accuracy: 0.8335 - val_loss: 1.0342 - val_accuracy: 0.7330
Epoch 17/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7484 - accuracy: 0.8262 - val_loss: 1.0523 - val_accuracy: 0.7112
Epoch 18/30
52/52 [==============================] - 21s 408ms/step - loss: 0.7209 - accuracy: 0.8450 - val_loss: 0.8146 - val_accuracy: 0.8174
Epoch 19/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7141 - accuracy: 0.8435 - val_loss: 0.8016 - val_accuracy: 0.7875
Epoch 20/30
52/52 [==============================] - 21s 410ms/step - loss: 0.7075 - accuracy: 0.8435 - val_loss: 0.9352 - val_accuracy: 0.7439
Epoch 21/30
52/52 [==============================] - 21s 406ms/step - loss: 0.7066 - accuracy: 0.8504 - val_loss: 1.0171 - val_accuracy: 0.7139
Epoch 22/30
52/52 [==============================] - 21s 405ms/step - loss: 0.6913 - accuracy: 0.8532 - val_loss: 0.7059 - val_accuracy: 0.8610
Epoch 23/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6681 - accuracy: 0.8671 - val_loss: 0.8007 - val_accuracy: 0.8147
Epoch 24/30
52/52 [==============================] - 21s 409ms/step - loss: 0.6636 - accuracy: 0.8747 - val_loss: 0.9490 - val_accuracy: 0.7302
Epoch 25/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6637 - accuracy: 0.8722 - val_loss: 0.6913 - val_accuracy: 0.8556
Epoch 26/30
52/52 [==============================] - 21s 406ms/step - loss: 0.6443 - accuracy: 0.8837 - val_loss: 1.0483 - val_accuracy: 0.7139
Epoch 27/30
52/52 [==============================] - 21s 407ms/step - loss: 0.6555 - accuracy: 0.8695 - val_loss: 0.9448 - val_accuracy: 0.7602
Epoch 28/30
52/52 [==============================] - 21s 409ms/step - loss: 0.6409 - accuracy: 0.8807 - val_loss: 0.9337 - val_accuracy: 0.7302
Epoch 29/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6300 - accuracy: 0.8910 - val_loss: 0.7461 - val_accuracy: 0.8256
Epoch 30/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6093 - accuracy: 0.8968 - val_loss: 0.8651 - val_accuracy: 0.7766
6/6 [==============================] - 0s 65ms/step - loss: 0.7059 - accuracy: 0.8610
Validation accuracy: 86.1%
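The training loss stays well above zero even as accuracy climbs, which is expected with label smoothing. The sketch below shows what a smoothing factor of 0.1 does to a one-hot target over 5 classes. It assumes the usual formulation, y * (1 - factor) + factor / num_classes, which to the best of our knowledge is what CategoricalCrossentropy applies. The function smooth_labels is a hypothetical helper.

```python
def smooth_labels(one_hot, factor):
    """Blend a one-hot target with a uniform distribution over the classes."""
    num_classes = len(one_hot)
    return [y * (1.0 - factor) + factor / num_classes for y in one_hot]


smoothed = smooth_labels([0.0, 0.0, 1.0, 0.0, 0.0], factor=0.1)
print([round(v, 2) for v in smoothed])  # [0.02, 0.02, 0.92, 0.02, 0.02]
```

Since the target never reaches exactly 1 for the true class, the cross-entropy has a non-zero floor even for a perfect classifier.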
# Serialize the model as a SavedModel.
tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")

# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
]
tflite_model = converter.convert()
with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
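The conversion code mentions post-training dynamic-range quantization. As rough intuition only, the sketch below shows symmetric 8-bit quantization of a few weights in plain Python. It is a simplified illustration under that assumption, not TFLite's actual implementation, and quantize_int8 and dequantize are hypothetical helper names.

```python
def quantize_int8(weights):
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float values from the integer representation."""
    return [q * scale for q in quantized]


weights = [0.31, -1.0, 0.07, 0.66]
quantized, scale = quantize_int8(weights)
restored = dequantize(quantized, scale)

# Each restored weight is within half a quantization step of the original.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(max_error <= scale / 2 + 1e-12)  # True
```

Storing one byte per weight instead of four is where most of the size reduction comes from, at the cost of a small rounding error per weight.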