RetinaNet class
keras_cv.models.RetinaNet(
    backbone,
    num_classes,
    bounding_box_format,
    anchor_generator=None,
    label_encoder=None,
    prediction_decoder=None,
    feature_pyramid=None,
    classification_head=None,
    box_head=None,
    **kwargs
)
A Keras model implementing the RetinaNet meta-architecture.
Implements the RetinaNet architecture for object detection. The constructor
requires num_classes, bounding_box_format, and a backbone. Optionally,
a custom label encoder and prediction decoder may be provided.
Example
import numpy as np
import tensorflow as tf
from tensorflow import keras
import keras_cv

# One dummy 512x512 RGB image with three boxes in "xywh" format
images = np.ones((1, 512, 512, 3))
labels = {
    "boxes": tf.cast([
        [
            [0, 0, 100, 100],
            [100, 100, 200, 200],
            [300, 300, 100, 100],
        ]
    ], dtype=tf.float32),
    "classes": tf.cast([[1, 1, 1]], dtype=tf.float32),
}
model = keras_cv.models.RetinaNet(
    num_classes=20,
    bounding_box_format="xywh",
    backbone=keras_cv.models.ResNet50Backbone.from_preset(
        "resnet50_imagenet"
    )
)
# Evaluate model without box decoding and NMS
model(images)
# Prediction with box decoding and NMS
model.predict(images)
# Train model
model.compile(
    classification_loss='focal',
    box_loss='smoothl1',
    optimizer=keras.optimizers.SGD(global_clipnorm=10.0),
    jit_compile=False,
)
model.fit(images, labels)
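The boxes in the labels above are written in the "xywh" format named by bounding_box_format. As a rough illustration of what that implies, the sketch below converts them to corner coordinates in plain Python. The helper xywh_to_xyxy is hypothetical and not part of KerasCV, and the sketch assumes "xywh" means (left, top, width, height) in absolute pixels.

```python
# Hypothetical helper for illustration only; not a KerasCV API.
def xywh_to_xyxy(box):
    """Convert (left, top, width, height) to (left, top, right, bottom)."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


# The three boxes from the example labels.
boxes_xywh = [
    [0, 0, 100, 100],
    [100, 100, 200, 200],
    [300, 300, 100, 100],
]
boxes_xyxy = [xywh_to_xyxy(b) for b in boxes_xywh]

# Every converted box should lie inside the 512x512 example image.
image_size = 512
all_inside = all(
    0 <= x1 < x2 <= image_size and 0 <= y1 < y2 <= image_size
    for x1, y1, x2, y2 in boxes_xyxy
)

print(boxes_xyxy)
print(all_inside)
```

Read as corner coordinates instead, the third box [300, 300, 100, 100] would have its right edge left of its left edge, which is why the declared format has to match the data.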
Arguments
- backbone: A keras.Model. If the default feature_pyramid is used, the backbone must implement the pyramid_level_inputs property with keys "P3", "P4", and "P5" and layer names as values. A sensible choice in many cases is keras_cv.models.ResNetBackbone.from_preset("resnet50_imagenet").
- num_classes: The number of object classes in the dataset.
- bounding_box_format: The format of the bounding boxes in the input dataset, for example "xywh".
- anchor_generator: (Optional) A keras_cv.layers.AnchorGenerator. If provided, the anchor generator is passed to both the label_encoder and the prediction_decoder. Use it only when label_encoder and prediction_decoder are both None. Defaults to an anchor generator with the parameterization strides=[2**i for i in range(3, 8)], scales=[2**x for x in [0, 1 / 3, 2 / 3]], sizes=[32.0, 64.0, 128.0, 256.0, 512.0], and aspect_ratios=[0.5, 1.0, 2.0].
- label_encoder: (Optional) A layer whose call() method returns RetinaNet training targets. By default, a KerasCV standard RetinaNetLabelEncoder is created and used. The results of this object's call() method are passed to the loss objects for box_loss and classification_loss as the y_true argument.
- prediction_decoder: (Optional) A keras.layers.Layer that is responsible for transforming RetinaNet predictions into usable bounding box Tensors. If not provided, the default is a keras_cv.layers.MultiClassNonMaxSuppression layer, which uses Non-Max Suppression for box pruning.
- feature_pyramid: (Optional) A keras.layers.Layer that produces a list of 4D feature maps (batch dimension included) when called on the pyramid-level outputs of the backbone. If not provided, the reference implementation from the paper is used.
- classification_head: (Optional) A keras.Layer that performs classification of the bounding boxes. If not provided, a simple ConvNet with 3 layers is used.
- box_head: (Optional) A keras.Layer that performs regression of the bounding boxes. If not provided, a simple ConvNet with 3 layers is used.
from_preset method
RetinaNet.from_preset()
Instantiate RetinaNet model from preset config and weights.
Arguments
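- preset: String. Name of the preset to load.
- load_weights: Whether to load pretrained weights into the model. Defaults to None, which follows whether the preset has pretrained weights available.
- input_shape: Input shape passed to the backbone on initialization. Defaults to None. If None, the preset value will be used.
Example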
# Load architecture and weights from preset
model = keras_cv.models.RetinaNet.from_preset(
    "resnet50_imagenet",
)
# Load randomly initialized model from preset architecture, without pretrained weights
model = keras_cv.models.RetinaNet.from_preset(
    "resnet50_imagenet",
    load_weights=False,
)
Preset name | Parameters | Description |
---|---|---|
retinanet_resnet50_pascalvoc | 35.60M | RetinaNet with a ResNet50 v1 backbone. Trained on the PascalVOC 2012 object detection task, which consists of 20 classes. This model achieves a final mAP of 0.33 on the evaluation set. |
PredictionHead class
keras_cv.models.retinanet.PredictionHead(
    output_filters, bias_initializer, num_conv_layers=3, **kwargs
)
The class/box predictions head.
Arguments
Returns
A function representing either the classification or the box regression head, depending on output_filters.
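How output_filters distinguishes the two heads can be sketched with simple arithmetic. The sketch assumes the layout from the RetinaNet paper, where each feature-map location predicts one set of class scores and four box offsets per anchor; it reuses the default scales, aspect_ratios, and strides listed for anchor_generator and num_classes=20 from the example. It does not call KerasCV, and the exact values the library passes to PredictionHead are not confirmed here.

```python
import math

# Default anchor parameterization quoted in the anchor_generator description.
strides = [2**i for i in range(3, 8)]
scales = [2**x for x in [0, 1 / 3, 2 / 3]]
aspect_ratios = [0.5, 1.0, 2.0]

# One anchor per (scale, aspect ratio) pair at every feature-map location.
anchors_per_location = len(scales) * len(aspect_ratios)

num_classes = 20  # as in the RetinaNet example

# Assumption: heads follow the RetinaNet paper layout.
classification_filters = anchors_per_location * num_classes
box_filters = anchors_per_location * 4  # four box offsets per anchor

# Total anchors for a 512x512 input, assuming each pyramid level has
# ceil(512 / stride) locations per side.
image_size = 512
total_anchors = sum(
    math.ceil(image_size / s) ** 2 * anchors_per_location for s in strides
)

print(anchors_per_location, classification_filters, box_filters, total_anchors)
```

Under these assumptions the two heads differ only in their final filter count (180 versus 36 here), which matches the description of a single PredictionHead class serving both roles.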