### `BASNet` class

```python
keras_cv.models.BASNet(
    backbone,
    num_classes,
    input_shape=(None, None, 3),
    input_tensor=None,
    include_rescaling=False,
    projection_filters=64,
    prediction_heads=None,
    refinement_head=None,
    **kwargs
)
```
A Keras model implementing the BASNet architecture for semantic segmentation.
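Arguments

- **backbone**: `keras.Model`. The backbone network used as the feature extractor for the BASNet prediction encoder. Currently supported backbones are ResNet18 and ResNet34. Default backbone is `keras_cv.models.ResNet34Backbone()`. (Note: do not specify `input_shape`, `input_tensor`, or `include_rescaling` within the backbone. Provide these when initializing the `BASNet` model.)
- **num_classes**: int, the number of classes for the segmentation model.
- **input_shape**: optional shape tuple for the model input. Defaults to `(None, None, 3)`.
- **input_tensor**: optional Keras tensor (i.e. the output of `layers.Input()`) to use as image input for the model.
- **include_rescaling**: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `False`.
- **projection_filters**: int, number of filters in the convolution layer that projects low-level features from the `backbone`. Defaults to `64`.
- **prediction_heads**: optional list of `keras.layers.Layer` defining the prediction module head for the model. If not provided, a default head is created with a Conv2D layer followed by resizing.
- **refinement_head**: optional `keras.layers.Layer` defining the refinement module head for the model. If not provided, a default head is created with a Conv2D layer.

Example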
```python
import keras
import keras_cv
import numpy as np

images = np.ones(shape=(1, 288, 288, 3))
labels = np.zeros(shape=(1, 288, 288, 1))

# Note: Do not specify 'input_shape', 'input_tensor', or
# 'include_rescaling' within the backbone.
backbone = keras_cv.models.ResNet34Backbone()
model = keras_cv.models.BASNet(
    backbone=backbone,
    num_classes=1,
    input_shape=[288, 288, 3],
    include_rescaling=False,
)

# Run inference; `output[0]` is used as the predicted mask
output = model(images)
pred_labels = output[0]

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
```
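The example feeds arrays straight into the model and reads `output[0]`. The NumPy sketch below shows the arithmetic around that call under stated assumptions: what `include_rescaling=True` does to pixel values (the same multiplication as a `Rescaling(1/255.0)` layer), and how a per-pixel prediction could be turned into a binary mask. The helper names and the 0.5 threshold are hypothetical, and treating the predictions as probabilities is an assumption based on the example compiling with `from_logits=False`.

```python
import numpy as np


def rescale_like_include_rescaling(images):
    # Same arithmetic as a Rescaling(1/255.0) layer: x * (1/255.0).
    return images.astype("float32") * (1.0 / 255.0)


def probabilities_to_mask(probs, threshold=0.5):
    # Hypothetical helper: map per-pixel probabilities of shape
    # (batch, height, width, 1) to a binary {0, 1} mask.
    # The 0.5 threshold is an assumption, not a KerasCV default.
    return (probs >= threshold).astype("uint8")


# uint8 pixels in [0, 255] map to [0, 1] after rescaling.
raw = np.array([[[[0, 128, 255]]]], dtype="uint8")
scaled = rescale_like_include_rescaling(raw)

# Stand-in for `pred_labels`; a real prediction would come from the model.
fake_probs = np.array([[[[0.2], [0.5], [0.9]]]], dtype="float32")
mask = probabilities_to_mask(fake_probs)
print(mask.ravel())
```

One consequence of the rescaling arithmetic: if images are already divided by 255 before they reach the model, keep `include_rescaling=False`, otherwise they are scaled twice and end up in `[0, 1/255]`.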
----
<span style="float:right;">[[source]](https://github.com/keras-team/keras-cv/tree/v0.9.0/keras_cv/src/models/task.py#L183)</span>
### `from_preset` method
```python
BASNet.from_preset()
```

Instantiate a BASNet model from a preset configuration and weights.
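Arguments

- **preset**: string, the name of the preset to instantiate, such as one of the presets listed in the table below.
- **load_weights**: optional bool, whether to load pretrained weights into the model. Defaults to `None`, which follows whether the preset has pretrained weights available.
- **input_shape**: input shape passed to the backbone at initialization. Defaults to `None`. If `None`, the preset value will be used.

Example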
```python
import keras_cv

# Load the architecture (and weights, if the preset provides them) from a preset
model = keras_cv.models.BASNet.from_preset(
    "basnet_resnet34",
)

# Load a randomly initialized model from the preset architecture, without weights
model = keras_cv.models.BASNet.from_preset(
    "basnet_resnet34",
    load_weights=False,
)
```
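The documented defaults for `load_weights` and `input_shape` amount to two small resolution rules. The sketch below restates them in plain Python; `resolve_load_weights` and `resolve_input_shape` are hypothetical helpers written for illustration, not KerasCV functions, and the real `from_preset` may perform additional validation.

```python
def resolve_load_weights(load_weights, preset_has_weights):
    # Hypothetical helper: `None` means "load weights only if the preset
    # actually provides pretrained weights"; an explicit bool wins.
    if load_weights is None:
        return preset_has_weights
    return bool(load_weights)


def resolve_input_shape(input_shape, preset_input_shape):
    # Hypothetical helper: if `input_shape` is None, fall back to the
    # value stored in the preset; otherwise use the caller's shape.
    return preset_input_shape if input_shape is None else input_shape


# Default behaviour follows the preset.
print(resolve_load_weights(None, preset_has_weights=False))  # False
# Explicit load_weights=False, as in the second example, always skips weights.
print(resolve_load_weights(False, preset_has_weights=True))  # False
# (288, 288, 3) is an illustrative preset value, not taken from a real preset.
print(resolve_input_shape(None, (288, 288, 3)))  # (288, 288, 3)
```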
| Preset name | Parameters | Description |
|---|---|---|
| basnet_resnet18 | 98.78M | BASNet with a ResNet18 v1 backbone. |
| basnet_resnet34 | 108.90M | BASNet with a ResNet34 v1 backbone. |
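As a rough guide to what the parameter counts in the table mean in practice, the sketch below converts them into the storage needed for the weights alone, assuming 4-byte float32 parameters and 1 MB = 10^6 bytes. It ignores optimizer state, activations, and file-format overhead, so memory use during training will be considerably higher.

```python
# Parameter counts copied from the preset table above.
PRESET_PARAMS = {
    "basnet_resnet18": 98.78e6,
    "basnet_resnet34": 108.90e6,
}

BYTES_PER_FLOAT32 = 4


def weight_size_mb(preset, bytes_per_param=BYTES_PER_FLOAT32):
    # Size of the weights alone in megabytes (1 MB = 10**6 bytes).
    return PRESET_PARAMS[preset] * bytes_per_param / 1e6


for name in PRESET_PARAMS:
    print(f"{name}: ~{weight_size_mb(name):.1f} MB of float32 weights")
# basnet_resnet18: ~395.1 MB of float32 weights
# basnet_resnet34: ~435.6 MB of float32 weights
```

Passing `bytes_per_param=2` gives the corresponding estimate for 16-bit weights, which halves both figures.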