SegFormer
class keras_cv.models.SegFormer(backbone, num_classes, projection_filters=256, **kwargs)
A Keras model implementing the SegFormer architecture for semantic segmentation.
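The sketch below roughly illustrates the all-MLP decoder described in the SegFormer paper. It uses NumPy. Each backbone pyramid level is projected to a common width and upsampled to the P2 (1/4) resolution. The levels are then concatenated, fused, and classified per pixel. Several things here are assumptions and are not confirmed by this page:

- P2 through P5 have strides 4, 8, 16, and 32.
- The stage widths are those of MiT-B0.
- projection_filters is the shared decoder width.

The weights are random, and nearest-neighbour upsampling replaces the paper's bilinear interpolation. This is not KerasCV's actual layer stack.

```python
import numpy as np

# Assumption: pyramid level P_k has stride 2**k, so P2..P5 are 1/4..1/32 of the input.
STRIDES = {"P2": 4, "P3": 8, "P4": 16, "P5": 32}
# Assumption: MiT-B0 stage widths as reported in the SegFormer paper.
CHANNELS = {"P2": 32, "P3": 64, "P4": 160, "P5": 256}


def fake_pyramid(height, width, rng):
    """Random stand-ins for the backbone's P2..P5 feature maps (batch of 1, NHWC)."""
    return {
        level: rng.standard_normal((1, height // s, width // s, CHANNELS[level]))
        for level, s in STRIDES.items()
    }


def upsample_nearest(x, factor):
    """Nearest-neighbour upsampling of an NHWC array by an integer factor."""
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


def mlp_decoder(features, num_classes, projection_filters=256, seed=0):
    """All-MLP decoder sketch with random weights; returns per-pixel class probabilities."""
    rng = np.random.default_rng(seed)
    target = features["P2"].shape[1]  # decode at the P2 (1/4) resolution
    projected = []
    for x in features.values():
        # Per-pixel linear projection of each level to projection_filters channels.
        w = rng.standard_normal((x.shape[-1], projection_filters)) * 0.02
        projected.append(upsample_nearest(x @ w, target // x.shape[1]))
    stacked = np.concatenate(projected, axis=-1)  # 4 * projection_filters channels
    w_fuse = rng.standard_normal((stacked.shape[-1], projection_filters)) * 0.02
    fused = np.maximum(stacked @ w_fuse, 0.0)  # linear fusion followed by ReLU
    w_cls = rng.standard_normal((projection_filters, num_classes)) * 0.02
    logits = fused @ w_cls
    # Numerically stable softmax over the class axis.
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)


features = fake_pyramid(96, 96, np.random.default_rng(0))
probs = mlp_decoder(features, num_classes=19)
print(probs.shape)  # (1, 24, 24, 19)
```

The sketch stops at 1/4 of the input resolution, where the paper's head makes its prediction. Whether KerasCV also resizes the mask back to the input size is not stated on this page.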
Arguments
backbone: keras.Model. The backbone network for the model that is used as a feature extractor for the SegFormer encoder. It is intended to be used only with the MiT backbone model, which was created specifically for SegFormers. It should be either a keras_cv.models.backbones.backbone.Backbone or a tf.keras.Model that implements the pyramid_level_inputs property, with keys "P2", "P3", "P4", and "P5" and layer names as values.
Example
Using the class with a backbone:
```python
import numpy as np
import keras
import keras_cv

images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
model = keras_cv.models.SegFormer(
    num_classes=1, backbone=backbone,
)

# Run inference
model(images)

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
```
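The compile call above trains against one-channel masks of shape (1, 96, 96, 1). It uses binary cross-entropy on probabilities (from_logits=False). The NumPy function below sketches what that loss measures per pixel by computing the mean pixel-wise binary cross-entropy. Its clipping epsilon and plain-mean reduction are assumptions, so it approximates keras.losses.BinaryCrossentropy rather than reproducing it.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean per-pixel binary cross-entropy on probabilities (i.e. from_logits=False)."""
    # Clip probabilities so log() never sees exactly 0 or 1.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_pixel = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    # Average over batch, height, width, and the single channel.
    return float(per_pixel.mean())


labels = np.zeros(shape=(1, 96, 96, 1))
uncertain = np.full((1, 96, 96, 1), 0.5)   # predicts 0.5 for every pixel
confident = np.full((1, 96, 96, 1), 0.01)  # predicts background almost everywhere

print(binary_crossentropy(labels, uncertain))  # ln(2), about 0.693
print(binary_crossentropy(labels, confident))
```

The all-zero labels from the example reward predictions close to 0, so the loss drops as the predicted probabilities shrink.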
from_preset method
SegFormer.from_preset(
preset, num_classes, load_weights=None, input_shape=None, **kwargs
)
Instantiate SegFormer model from preset config and weights.
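Arguments
load_weights: whether to load pretrained weights into the model. Defaults to None, which follows whether the preset has pretrained weights available.
input_shape: the model's input shape. Defaults to None. If None, the preset value will be used.
Example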
```python
import keras_cv

# Load architecture and weights from preset
model = keras_cv.models.SegFormer.from_preset(
    "segformer_b0_imagenet",
)

# Load randomly initialized model from preset architecture without weights
model = keras_cv.models.SegFormer.from_preset(
    "segformer_b0_imagenet",
    load_weights=False,
)
```
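The load_weights default described above follows whether the preset has pretrained weights. The pure-Python sketch below illustrates that rule. It is not KerasCV's implementation: PRESETS_WITH_WEIGHTS and resolve_load_weights are hypothetical names. The weighted set is inferred from the two "_imagenet" rows in the preset table, which are the only ones described as pretrained. Raising an error when weights are requested but unavailable is also an assumption.

```python
# Hypothetical registry: the "_imagenet" rows in the preset table are the ones
# described as pretrained, so they are treated here as the presets with weights.
PRESETS_WITH_WEIGHTS = {"mit_b0_imagenet", "segformer_b0_imagenet"}


def resolve_load_weights(preset, load_weights=None):
    """Hypothetical helper mirroring the documented load_weights default."""
    has_weights = preset in PRESETS_WITH_WEIGHTS
    if load_weights is None:
        # Default: follow whether the preset has pretrained weights available.
        return has_weights
    if load_weights and not has_weights:
        # Assumption: asking for weights that do not exist is an error.
        raise ValueError(f"Preset {preset!r} has no pretrained weights.")
    return bool(load_weights)


print(resolve_load_weights("segformer_b0_imagenet"))         # True
print(resolve_load_weights("segformer_b0_imagenet", False))  # False
print(resolve_load_weights("segformer_b0"))                  # False
```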
Preset name | Parameters | Description |
---|---|---|
mit_b0 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks. |
mit_b1 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks. |
mit_b2 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks. |
mit_b3 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks. |
mit_b4 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks. |
mit_b5 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks. |
mit_b0_imagenet | 3.32M | MiT (MixTransformer) model with 8 transformer blocks. Pre-trained on ImageNet-1K and scores 69% top-1 accuracy on the validation set. |
segformer_b0 | 3.72M | SegFormer model with MiTB0 backbone. |
segformer_b1 | 13.68M | SegFormer model with MiTB1 backbone. |
segformer_b2 | 24.73M | SegFormer model with MiTB2 backbone. |
segformer_b3 | 44.60M | SegFormer model with MiTB3 backbone. |
segformer_b4 | 61.37M | SegFormer model with MiTB4 backbone. |
segformer_b5 | 81.97M | SegFormer model with MiTB5 backbone. |
segformer_b0_imagenet | 3.72M | SegFormer model with a pretrained MiTB0 backbone. |
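In the table above, each segformer_bN row is larger than its mit_bN row. The gap is about 0.40M parameters for B0 and 0.52M to 0.53M for B1 to B5. The arithmetic sketch below estimates a decoder head of that size under these assumptions:

- Each pyramid level gets a dense projection to projection_filters=256 channels, with bias.
- The fusion step is a bias-free 1x1 convolution followed by batch normalization (four vectors of 256).
- The classifier is a 1x1 convolution with bias.
- The stage widths are the MiT widths from the SegFormer paper.
- num_classes=19 is hypothetical, because this page does not state the class count built into the presets. Each extra class adds 257 parameters.

The agreement is suggestive, not proof of KerasCV's exact layer stack.

```python
def decoder_params(backbone_channels, num_classes, projection_filters=256):
    """Estimated parameter count of an all-MLP SegFormer-style decoder head."""
    # Per-level linear projection to projection_filters channels: weights + bias.
    proj = sum(c * projection_filters + projection_filters for c in backbone_channels)
    # Fusion 1x1 conv over the concatenated levels (assumed bias-free).
    fuse = len(backbone_channels) * projection_filters * projection_filters
    # Batch norm: gamma, beta, moving mean, moving variance.
    bn = 4 * projection_filters
    # Classifier 1x1 conv with bias.
    cls = projection_filters * num_classes + num_classes
    return proj + fuse + bn + cls


b0 = decoder_params([32, 64, 160, 256], num_classes=19)   # MiT-B0 widths
b1 = decoder_params([64, 128, 320, 512], num_classes=19)  # MiT-B1..B5 widths
print(b0, b1)  # 400147 531219
```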
SegFormerB0
class keras_cv.models.SegFormerB0(backbone, num_classes, projection_filters=256, **kwargs)
SegFormer model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Example
```python
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Backbone loaded with ImageNet-pretrained MiT-B0 weights
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
segformer = keras_cv.models.SegFormer(backbone=backbone, num_classes=19)
output = segformer(input_data)
```
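The SegFormerB0 to SegFormerB5 sections share one docstring, so they do not say how the variants differ. The preset table pairs each segformer_bN with an MiT-BN backbone. Assuming each SegFormerBN class corresponds to that preset, the sketch below records per-stage encoder depths and channel widths as reported in the SegFormer paper. Their depths sum to the transformer-block counts listed for mit_b0 to mit_b5. Treating these as KerasCV's exact configs is an assumption, and variant_summary is a hypothetical helper.

```python
# Per-stage MiT encoder depths and widths as reported in the SegFormer paper
# (assumed, not confirmed here, to match KerasCV's mit_b0..mit_b5 presets).
WIDE = [64, 128, 320, 512]
MIT_CONFIGS = {
    "b0": {"depths": [2, 2, 2, 2], "widths": [32, 64, 160, 256]},
    "b1": {"depths": [2, 2, 2, 2], "widths": WIDE},
    "b2": {"depths": [3, 4, 6, 3], "widths": WIDE},
    "b3": {"depths": [3, 4, 18, 3], "widths": WIDE},
    "b4": {"depths": [3, 8, 27, 3], "widths": WIDE},
    "b5": {"depths": [3, 6, 40, 3], "widths": WIDE},
}


def variant_summary(class_name):
    """Hypothetical helper: map "SegFormerB3" to its MiT backbone preset and sizes."""
    key = class_name[-2:].lower()  # "SegFormerB3" -> "b3"
    config = MIT_CONFIGS[key]
    return {
        "backbone_preset": f"mit_{key}",
        "total_blocks": sum(config["depths"]),
        "widths": config["widths"],
    }


for i in range(6):
    name = f"SegFormerB{i}"
    print(name, variant_summary(name))
```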
SegFormerB1
class keras_cv.models.SegFormerB1(backbone, num_classes, projection_filters=256, **kwargs)
SegFormer model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Example
```python
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Backbone loaded with ImageNet-pretrained MiT-B0 weights
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
segformer = keras_cv.models.SegFormer(backbone=backbone, num_classes=19)
output = segformer(input_data)
```
SegFormerB2
class keras_cv.models.SegFormerB2(backbone, num_classes, projection_filters=256, **kwargs)
SegFormer model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Example
```python
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Backbone loaded with ImageNet-pretrained MiT-B0 weights
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
segformer = keras_cv.models.SegFormer(backbone=backbone, num_classes=19)
output = segformer(input_data)
```
SegFormerB3
class keras_cv.models.SegFormerB3(backbone, num_classes, projection_filters=256, **kwargs)
SegFormer model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Example
```python
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Backbone loaded with ImageNet-pretrained MiT-B0 weights
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
segformer = keras_cv.models.SegFormer(backbone=backbone, num_classes=19)
output = segformer(input_data)
```
SegFormerB4
class keras_cv.models.SegFormerB4(backbone, num_classes, projection_filters=256, **kwargs)
SegFormer model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Example
```python
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Backbone loaded with ImageNet-pretrained MiT-B0 weights
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
segformer = keras_cv.models.SegFormer(backbone=backbone, num_classes=19)
output = segformer(input_data)
```
SegFormerB5
class keras_cv.models.SegFormerB5(backbone, num_classes, projection_filters=256, **kwargs)
SegFormer model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Example
```python
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Backbone loaded with ImageNet-pretrained MiT-B0 weights
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b0_imagenet")
segformer = keras_cv.models.SegFormer(backbone=backbone, num_classes=19)
output = segformer(input_data)
```