LayoutMap class
keras.distribution.LayoutMap(device_mesh)
A dict-like object that maps strings to TensorLayout instances.

LayoutMap uses a string as its key and a TensorLayout as its value. Unlike a normal Python dict, the string key is treated as a regex when retrieving the value. See the docstring of get for more details.
See below for a usage example. You can define the naming schema of the TensorLayout, and then retrieve the corresponding TensorLayout instance.

In the normal case, the key used to query is the variable.path, which is the identifier of the variable.

As a shortcut, a tuple or list of axis names is also allowed as an inserted value, and it will be converted to a TensorLayout.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
layout_1 = layout_map['dense_1.kernel']              # 2D layout: (None, 'model')
layout_2 = layout_map['dense_1.bias']                # 1D layout: ('model',)
layout_3 = layout_map['dense_2.kernel']              # 2D layout: (None, 'model')
layout_4 = layout_map['dense_2.bias']                # 1D layout: ('model',)
layout_5 = layout_map['my_model/conv2d_123/kernel']  # 4D layout: (None, None, None, 'model')
layout_6 = layout_map['my_model/conv2d_123/bias']    # 1D layout: ('model',)
layout_7 = layout_map['my_model/conv3d_1/kernel']    # no matching rule: None
layout_8 = layout_map['my_model/conv3d_1/bias']      # no matching rule: None
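The tuples assigned above are converted to TensorLayout instances attached to the map's device_mesh, so the retrieved objects can be used directly as layouts. Below is a minimal sketch of that behavior; the single-axis mesh and the key 'lstm_1.kernel' are only illustrative:

import keras

# Hypothetical 1D mesh with a single 'model' axis over all available devices.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=('model',), devices=devices)

layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')      # tuple shortcut for a layout
layout = layout_map['dense_1.kernel']              # key is matched as a regex
# `layout` is a TensorLayout built from the (None, 'model') tuple above.
assert isinstance(layout, keras.distribution.TensorLayout)
assert layout_map['lstm_1.kernel'] is None         # no rule matches this key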
Arguments
device_mesh: A keras.distribution.DeviceMesh instance.
DeviceMesh class
keras.distribution.DeviceMesh(shape, axis_names, devices=None)
A cluster of computation devices for distributed computation.

This API is aligned with jax.sharding.Mesh and tf.dtensor.Mesh, which represent the computation devices in the global context. See jax.sharding.Mesh and tf.dtensor.Mesh for more details.
Arguments
shape: The shape of the overall DeviceMesh, e.g. (8,) for a data parallel only distribution, or (4, 2) for a model+data parallel distribution.
axis_names: The logical name of each axis of the DeviceMesh. The length of axis_names should match the rank of shape. The axis_names are used to match/create the TensorLayout when distributing the data and variables.
devices: Optional list of devices. Defaults to all the available devices from keras.distribution.list_devices().
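As a concrete illustration of these arguments, the sketch below builds a mesh for combined data and model parallelism; it assumes 8 devices are visible, and the axis names 'batch' and 'model' are only examples:

import keras

devices = keras.distribution.list_devices()   # assumed to return 8 accelerators
# Arrange them as a 4 x 2 grid: 4-way data parallelism, 2-way model parallelism.
device_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2),
    axis_names=('batch', 'model'),
    devices=devices)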
TensorLayout class
keras.distribution.TensorLayout(axes, device_mesh=None)
A layout to apply to a tensor.

This API is aligned with jax.sharding.NamedSharding and tf.dtensor.Layout. See jax.sharding.NamedSharding and tf.dtensor.Layout for more details.
Arguments
axes: A sequence of strings that should be composed of the axis_names in a DeviceMesh. For any dimension that doesn't need sharding, None can be used as a placeholder.
device_mesh: Optional DeviceMesh that will be used to create the layout. The actual mapping of the tensor to physical devices is not known until the mesh is specified.
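For example, a layout that shards the last dimension of a 2D kernel across a 'model' mesh axis while leaving the first dimension unsharded might look like the sketch below (the mesh shape and axis names are illustrative):

import keras

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=('batch', 'model'), devices=devices)
# None leaves the first dimension unsharded; 'model' shards the second
# dimension across the 'model' axis of the mesh.
kernel_layout = keras.distribution.TensorLayout(
    axes=(None, 'model'), device_mesh=device_mesh)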
distribute_tensor function
keras.distribution.distribute_tensor(tensor, layout)
Change the layout of a tensor value during jit function execution.
Arguments
tensor: The tensor whose layout should be changed.
layout: TensorLayout to be applied on the value.

Returns
A new value with the specified tensor layout.
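A minimal end-to-end sketch (JAX backend assumed; the mesh, axis name, and array shape are illustrative):

import numpy as np
import keras

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=('model',), devices=devices)
layout = keras.distribution.TensorLayout(axes=('model', None), device_mesh=device_mesh)

# Shard the first dimension of the value across the 'model' axis.
x = np.ones((len(devices) * 2, 8), dtype='float32')
x_distributed = keras.distribution.distribute_tensor(x, layout)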