SAMMaskDecoder
keras_hub.layers.SAMMaskDecoder(
    hidden_size,
    num_layers,
    intermediate_dim,
    num_heads,
    embedding_dim=256,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
    activation="gelu",
    **kwargs
)
Mask decoder for the Segment Anything Model (SAM).
This lightweight module efficiently maps the image embedding and a set of prompt embeddings to an output mask. Before applying the transformer decoder, the layer first inserts into the set of prompt embeddings a learned output token embedding that will be used at the decoder's output. For simplicity, these embeddings (not including the image embedding) are collectively called "tokens".
The image embeddings, positional image embeddings, and tokens are passed through a transformer decoder. After running the decoder, the layer upsamples the updated image embedding by 4x with two transposed convolutional layers (the embedding is then downscaled 4x relative to the input image). The tokens then attend once more to the image embedding, and the updated output token embeddings are passed to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding.
Finally, a mask is predicted with a spatially point-wise product between the upscaled image embedding and the MLP's output.
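To make this final step concrete, here is a minimal sketch of the spatially point-wise product. It is illustrative only, not the layer's actual implementation; the tensor names, shapes, and the number of masks are assumptions.

import numpy as np

batch, num_masks, channels = 1, 4, 32
height = width = 256  # upscaled embedding resolution (4x the decoder's input embedding)

# Per-mask vectors produced by the small 3-layer MLP from the output tokens.
hyper_in = np.random.rand(batch, num_masks, channels).astype("float32")
# Image embedding after the two transposed-convolution upsampling layers.
upscaled_embedding = np.random.rand(batch, height, width, channels).astype("float32")

# Spatially point-wise product: each mask is the dot product between its
# MLP vector and the upscaled embedding at every spatial location.
masks = np.einsum("bnc,bhwc->bnhw", hyper_in, upscaled_embedding)
print(masks.shape)  # (1, 4, 256, 256)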
Arguments
- hidden_size: int. The hidden size of the two-way transformer decoder.
- num_layers: int. The number of layers in the two-way transformer decoder.
- intermediate_dim: int. The intermediate dimension of the two-way transformer decoder.
- num_heads: int. The number of attention heads in the two-way transformer decoder.
- embedding_dim: int, optional. The number of input features to the transformer decoder. Defaults to 256.
- num_multimask_outputs: int, optional. The number of extra masks to generate. The total number of masks generated by the model is 1 + num_multimask_outputs. Defaults to 3.
- iou_head_depth: int, optional. The depth of the MLP used to predict the IoU confidence score. Defaults to 3.
- iou_head_hidden_dim: int, optional. The number of units in the hidden layers of the MLP used to predict the IoU confidence score. Defaults to 256.
- activation: str, optional. The activation to use in the mask upscaler network. Defaults to "gelu".
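For illustration, a minimal construction sketch follows. The values chosen for hidden_size, num_layers, intermediate_dim, and num_heads are assumptions reflecting a typical SAM configuration, not requirements of the layer.

import keras_hub

# Build a SAM-style mask decoder. The required transformer arguments are
# set to assumed values typical of the original SAM release; adjust as needed.
mask_decoder = keras_hub.layers.SAMMaskDecoder(
    hidden_size=256,
    num_layers=2,
    intermediate_dim=2048,
    num_heads=8,
    embedding_dim=256,          # feature size of the transformer inputs
    num_multimask_outputs=3,    # total masks produced: 1 + 3 = 4
    iou_head_depth=3,           # layers in the IoU-score MLP
    iou_head_hidden_dim=256,    # hidden units in the IoU-score MLP
    activation="gelu",
)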