SegmentAnythingModel class
keras_cv.models.SegmentAnythingModel(
backbone, prompt_encoder, mask_decoder, **kwargs
)
The Segment Anything (SAM) Model.
Arguments

- backbone: keras_cv.models.Backbone. A feature extractor for the input images.
- prompt_encoder: keras_cv.models.SAMPromptEncoder. A Keras layer that computes embeddings for the point, box, and mask prompts.
- mask_decoder: keras_cv.models.SAMMaskDecoder. A Keras layer that generates the segmentation masks given the image embeddings and the prompt embeddings.

References

- Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643
Example
>>> import numpy as np
>>> from keras_cv.src.models import ViTDetBBackbone
>>> from keras_cv.src.models import SAMPromptEncoder
>>> from keras_cv.src.models import SAMMaskDecoder
Create all the components of the SAM model:
>>> backbone = ViTDetBBackbone()
>>> prompt_encoder = SAMPromptEncoder()
>>> mask_decoder = SAMMaskDecoder()
Instantiate the model:
>>> sam = SegmentAnythingModel(
... backbone=backbone,
... prompt_encoder=prompt_encoder,
... mask_decoder=mask_decoder
... )
Define the input of the backbone. This must be a batch of images of shape (1024, 1024, 3) for the ViT backbone we are using:
>>> image = np.ones((1, 1024, 1024, 3))
SAM works by prompting the input images. There are three ways to prompt:

(1) Labelled Points: Foreground points (points with label 1) are encoded such that the output masks generated by the mask decoder contain them, and background points (points with label 0) are encoded such that the generated masks don't contain them.
(2) Box: A box tells the model which part/crop of the image to segment.
(3) Mask: An input mask can be used to refine the output of the mask decoder.

These prompts can be mixed and matched, but at least one of the prompts must be present. To turn off a particular prompt, simply exclude it from the inputs to the model.
TODO(ianstenbit): Remove the need for the 1 axes, and fix the box shape.

(1) For point prompts, the expected shape is (batch, num_points, 2). The labels must have a corresponding shape of (batch, num_points).
(2) For the box prompt, the expected shape is (batch, 1, 2, 2).
(3) Similarly, mask prompts have shape (batch, 1, H, W, 1).
For example, to pass in all the prompts, do:
>>> points = np.array([[[512., 512.], [100., 100.]]])
>>> # For labels: 1 means foreground point, 0 means background
>>> labels = np.array([[1., 0.]])
>>> box = np.array([[[[384., 384.], [640., 640.]]]])
>>> input_mask = np.ones((1, 1, 256, 256, 1))
Prepare an input dictionary:
>>> inputs = {
... "images": image,
... "points": points,
... "labels": labels,
... "boxes": box,
... "masks": input_mask
... }
>>> outputs = sam.predict(inputs)
>>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]
The first mask in the output masks (i.e. masks[:, 0, ...]) is the best mask predicted by the model based on the prompts. The other masks (i.e. masks[:, 1:, ...]) are alternate predictions that can be used if they are desired over the first one.
Now, if only the points and box prompts are present, simply exclude the mask input:
>>> inputs = {
... "images": image,
... "points": points,
... "labels": labels,
... "boxes": box,
... }
>>> outputs = sam.predict(inputs)
>>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]
TODO(ianstenbit): Remove the need for this padding.

Another case is when only point prompts are present. Note that if point prompts are present but no box prompt is present, the points must be padded using a zero point and a -1 label:
>>> padded_points = np.concatenate(
... [points, np.zeros((1, 1, 2))], axis=1
... )
>>> padded_labels = np.concatenate(
... [labels, -np.ones((1, 1))], axis=1
... )
>>> inputs = {
... "images": image,
... "points": padded_points,
... "labels": padded_labels,
... }
>>> outputs = sam.predict(inputs)
>>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]
Note that the Segment Anything Model only supports inference; training isn't supported yet. Calling the fit method will therefore fail for now.
from_preset method
SegmentAnythingModel.from_preset()
Instantiate a SegmentAnythingModel from a preset config and weights.
Arguments

- preset: string. Must be one of the preset names listed below.
- load_weights: Whether to load pre-trained weights into the model. Defaults to None, which follows whether the preset has pretrained weights available.
- input_shape: Input shape passed to the backbone initialization. Defaults to None. If None, the preset value will be used.

Example
# Load architecture and weights from preset
model = keras_cv.models.SegmentAnythingModel.from_preset(
    "sam_base_sa1b",
)

# Load randomly initialized model from preset architecture, without weights
model = keras_cv.models.SegmentAnythingModel.from_preset(
    "sam_base_sa1b",
    load_weights=False,
)
| Preset name | Parameters | Description |
| --- | --- | --- |
| sam_base_sa1b | 93.74M | The base SAM model trained on the SA1B dataset. |
| sam_large_sa1b | 312.34M | The large SAM model trained on the SA1B dataset. |
| sam_huge_sa1b | 641.09M | The huge SAM model trained on the SA1B dataset. |
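As a sketch tying the pieces together (the preset name comes from the table above, and the input format follows the main example, including the zero-point/-1-label padding used when no box prompt is passed):

import numpy as np
import keras_cv

sam = keras_cv.models.SegmentAnythingModel.from_preset("sam_base_sa1b")
outputs = sam.predict({
    "images": np.ones((1, 1024, 1024, 3)),
    # One foreground point, padded with a zero point and a -1 label
    # because no box prompt is provided.
    "points": np.array([[[512., 512.], [0., 0.]]]),
    "labels": np.array([[1., -1.]]),
})
best_mask = outputs["masks"][:, 0, ...]  # the model's best mask, as described above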
SAMMaskDecoder class
keras_cv.models.SAMMaskDecoder(
transformer_dim=256,
transformer=None,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=256,
activation="gelu",
**kwargs
)
Mask decoder for the Segment Anything Model (SAM).
This lightweight module efficiently maps the image embedding and a set of prompt embeddings to an output mask. Before applying the transformer decoder, the layer first inserts into the set of prompt embeddings a learned output token embedding that will be used at the decoder's output. For simplicity, these embeddings (not including the image embedding) are collectively called "tokens".
The image embeddings, positional image embeddings, and tokens are passed through a transformer decoder. After running the decoder, the layer upsamples the updated image embedding by 4x with two transposed convolutional layers (now it's downscaled 4x relative to the input image). Then, the tokens attend once more to the image embedding, and the updated output token embedding is passed to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding. Finally, a mask is predicted with a spatially point-wise product between the upscaled image embedding and the MLP's output.
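As a rough NumPy sketch of the final spatially point-wise product described above (the shapes here, a 64x64 embedding upscaled 4x to 256x256 with 32 channels and 1 + 3 mask tokens, are illustrative assumptions rather than the layer's API):

import numpy as np

batch, num_masks, channels = 1, 4, 32  # 4 = 1 + num_multimask_outputs
upscaled_embedding = np.random.rand(batch, 256, 256, channels)  # 64x64 upsampled 4x
mlp_output = np.random.rand(batch, num_masks, channels)  # one vector per mask token

# Dot product over the channel axis at every spatial location yields one
# mask (logit map) per mask token.
masks = np.einsum("bnc,bhwc->bnhw", mlp_output, upscaled_embedding)
print(masks.shape)  # (1, 4, 256, 256)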
Arguments

- transformer_dim (int, optional): The number of input features to the transformer network. Defaults to 256.
- transformer (keras.layers.Layer, optional): A transformer network. When None, a keras_cv.models.TwoWayTransformer layer is used.
- num_multimask_outputs (int, optional): The number of extra masks to generate. The total number of masks produced by the model is 1 + num_multimask_outputs. Defaults to 3.
- iou_head_depth (int, optional): The depth of the MLP used to predict the IoU confidence score. Defaults to 3.
- iou_head_hidden_dim (int, optional): The number of units in the hidden layers of the MLP used to predict the IoU confidence score. Defaults to 256.
- activation (str, optional): Activation to use in the MLPs. Defaults to "gelu".

References

- Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643
SAMPromptEncoder class
keras_cv.models.SAMPromptEncoder(
embed_dim=256,
image_embedding_size=(64, 64),
input_image_size=(1024, 1024),
mask_in_chans=16,
activation="gelu",
**kwargs
)
Prompt Encoder for the Segment Anything Model (SAM).
The prompt encoder generates encodings for three types of prompts: labelled points, boxes, and masks.
First, the point prompts and box prompts are concatenated and positional encodings are generated using random spatial frequencies. A point is represented as the sum of a positional encoding of the point's location and one of two learned embeddings that indicate whether the point is in the foreground or background. A box is represented by an embedding pair: (1) the positional encoding of its top-left corner summed with a learned embedding representing "top-left corner", and (2) the same structure but using a learned embedding indicating "bottom-right corner". The box and point encodings are referred to as "sparse encodings".
If a mask prompt is passed, a convolutional neural net is used to downscale it to generate "dense encodings". If no mask prompt is passed, an embedding layer is used instead to generate a "no mask" embedding.
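A minimal NumPy sketch of the sparse point encoding described above (the positional-encoding function and the foreground/background vectors are random stand-ins for what the real layer learns):

import numpy as np

embed_dim = 256
rng = np.random.default_rng(0)
frequencies = rng.normal(size=(2, embed_dim // 2))    # random spatial frequencies
foreground_vec = rng.normal(size=(embed_dim,))        # learned in the real layer
background_vec = rng.normal(size=(embed_dim,))        # learned in the real layer

def positional_encoding(point):
    # Map a normalized (x, y) location to a sin/cos feature vector.
    projected = 2 * np.pi * (point @ frequencies)     # (embed_dim // 2,)
    return np.concatenate([np.sin(projected), np.cos(projected)])

point = np.array([512.0, 512.0]) / 1024.0             # normalize to [0, 1]
label = 1                                             # 1 = foreground, 0 = background
sparse_embedding = positional_encoding(point) + (
    foreground_vec if label == 1 else background_vec
)
print(sparse_embedding.shape)  # (256,)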
Arguments

- embed_dim (int, optional): The number of features in the output embeddings. Defaults to 256.
- image_embedding_size (tuple[int], optional): The size of the image embeddings generated by the image encoder. Defaults to (64, 64).
- input_image_size (tuple[int], optional): A tuple of the height and width of the image being prompted. Defaults to (1024, 1024).
- mask_in_chans (int, optional): The number of channels of the mask prompt. Defaults to 16.
- activation (str, optional): The activation to use in the mask downscaler network. Defaults to "gelu".

References

- Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643
TwoWayTransformer class
keras_cv.models.TwoWayTransformer(
depth=2,
embed_dim=256,
num_heads=8,
mlp_dim=2048,
activation="relu",
attention_downsample_rate=2,
**kwargs
)
A two-way cross-attention transformer decoder.
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
The transformer decoder design is shown in [1]. Each decoder layer performs 4 steps: (1) self-attention on the tokens, (2) cross-attention from tokens (as queries) to the image embedding, (3) a point-wise MLP updates each token, and (4) cross-attention from the image embedding (as queries) to tokens. This last step updates the image embedding with prompt information. Each self/cross-attention and MLP has a residual connection and layer normalization.
To ensure the decoder has access to critical geometric information, the positional encodings are added to the image embedding whenever they participate in an attention layer. Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer. This allows for a strong dependence on both the prompt tokens' geometric location and type.
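A heavily simplified sketch of one decoder layer's four steps, with single-head attention and the residual connections and layer norms omitted, just to show the data flow (the attend and mlp helpers are hypothetical, not the layer's API):

import numpy as np

def attend(queries, keys, values):
    # Plain single-head scaled dot-product attention.
    scores = queries @ keys.T / np.sqrt(keys.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return weights @ values

def decoder_layer(tokens, token_pe, image_emb, image_pe, mlp):
    # (1) Self-attention on the tokens (original positional encodings re-added).
    tokens = attend(tokens + token_pe, tokens + token_pe, tokens)
    # (2) Cross-attention: tokens (queries) -> image embedding.
    tokens = attend(tokens + token_pe, image_emb + image_pe, image_emb)
    # (3) Point-wise MLP updates each token.
    tokens = mlp(tokens)
    # (4) Cross-attention: image embedding (queries) -> tokens, so the image
    #     embedding picks up prompt information.
    image_emb = attend(image_emb + image_pe, tokens + token_pe, tokens)
    return tokens, image_emb

rng = np.random.default_rng(0)
tokens = rng.normal(size=(7, 256))                # e.g. prompt + output tokens
image_emb = rng.normal(size=(64 * 64, 256))       # flattened 64x64 image embedding
token_pe, image_pe = rng.normal(size=(7, 256)), rng.normal(size=(64 * 64, 256))
mlp = lambda x: np.maximum(x, 0.0)                # stand-in for the real MLP block
tokens, image_emb = decoder_layer(tokens, token_pe, image_emb, image_pe, mlp)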
Arguments

- depth (int, optional): The depth of the attention blocks (the number of attention blocks to use). Defaults to 2.
- embed_dim (int, optional): The number of features of the input image and point embeddings. Defaults to 256.
- num_heads (int, optional): Number of heads to use in the attention layers. Defaults to 8.
- mlp_dim (int, optional): The number of units in the hidden layer of the MLP block used in the attention layers. Defaults to 2048.
- activation (str, optional): The activation of the MLP block's output layer used in the attention layers. Defaults to "relu".
- attention_downsample_rate (int, optional): The downsample rate of the attention layers. Defaults to 2.

References

- [1] Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643
MultiHeadAttentionWithDownsampling class
keras_cv.layers.MultiHeadAttentionWithDownsampling(
num_heads, key_dim, downsample_rate=1, **kwargs
)
Multi-Head Attention with downsampling.
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values.
This layer first downscales the features of input queries, keys, and values using a dense layer. Multi-head attention is then performed and the attention map is projected back (upscaled) to the number of input features.
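A rough single-head NumPy sketch of the downsampling idea (the dense projection weights here are random stand-ins for the layer's learned ones):

import numpy as np

key_dim, downsample_rate = 256, 2
internal_dim = key_dim // downsample_rate            # feature size used inside attention

rng = np.random.default_rng(0)
w_q = rng.normal(size=(key_dim, internal_dim))       # downscale queries
w_k = rng.normal(size=(key_dim, internal_dim))       # downscale keys
w_v = rng.normal(size=(key_dim, internal_dim))       # downscale values
w_out = rng.normal(size=(internal_dim, key_dim))     # project back to key_dim

query, key, value = (rng.normal(size=(10, key_dim)) for _ in range(3))
scores = (query @ w_q) @ (key @ w_k).T / np.sqrt(internal_dim)
scores -= scores.max(axis=-1, keepdims=True)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
output = (weights @ (value @ w_v)) @ w_out           # back to the input feature size
print(output.shape)  # (10, 256)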
Arguments

- num_heads (int): Number of attention heads.
- key_dim (int): Size of each attention head for query, key, and value.
- downsample_rate (int, optional): The factor by which to downscale the input features, i.e. input features of size key_dim are projected down to key_dim // downsample_rate. Defaults to 1.

References

- Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643
TwoWayMultiHeadAttention class
keras_cv.layers.TwoWayMultiHeadAttention(
num_heads,
key_dim,
mlp_dim,
skip_first_layer_pe,
attention_downsample_rate=2,
activation="relu",
**kwargs
)
Two-way multi-head attention layer.
Arguments

- num_heads (int): Number of attention heads.
- key_dim (int): Size of each attention head for query, key, and value.
- mlp_dim (int): The number of units in the hidden layer of the MLP block.
- skip_first_layer_pe (bool): Whether to skip adding positional embeddings in the first layer's self-attention.
- attention_downsample_rate (int, optional): The downsample rate to use in the attention layers. Defaults to 2.
- activation (str, optional): The activation of the MLP block's output layer. Defaults to "relu".

References

- Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643
RandomFrequencyPositionalEmbeddings class
keras_cv.layers.RandomFrequencyPositionalEmbeddings(
num_positional_features, scale, **kwargs
)
Positional encoding using random spatial frequencies.
This layer maps coordinates/points in 2D space to positional encodings using random spatial frequencies.
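A minimal NumPy sketch of the idea, i.e. random Fourier features over normalized 2D coordinates (the exact normalization and scaling inside the layer may differ):

import numpy as np

num_positional_features, scale = 128, 1.0
rng = np.random.default_rng(0)
# Random spatial frequencies, drawn once at construction time.
frequencies = scale * rng.normal(size=(2, num_positional_features))

coords = np.array([[0.25, 0.75], [0.5, 0.5]])         # points in [0, 1] x [0, 1]
projected = 2 * np.pi * coords @ frequencies          # (2, num_positional_features)
encodings = np.concatenate([np.sin(projected), np.cos(projected)], axis=-1)
print(encodings.shape)  # (2, 256): 2 * num_positional_features features per point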
Arguments

- num_positional_features (int): Number of positional features in the output encoding.
- scale (float): The scale factor applied to the randomly sampled spatial frequencies.

References

- Segment Anything (Kirillov et al., 2023): https://arxiv.org/abs/2304.02643