set_distribution function
keras.distribution.set_distribution(value)
Set the distribution as the global distribution setting.
Arguments
value: A Distribution instance.
distribution function
keras.distribution.distribution()
Retrieve the current distribution from global context.
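As an illustration of how set_distribution and distribution fit together, here is a minimal sketch that installs a distribution globally and reads it back. It assumes a backend with distribution support (for example JAX) and relies on keras.distribution.DataParallel, which is not described in this section. The name enable_data_parallel is a hypothetical helper, not part of Keras.

```python
def enable_data_parallel():
    """Install a DataParallel distribution globally and return the active one.

    Hypothetical helper for illustration; not part of Keras.
    """
    # Imported here so that KERAS_BACKEND can be set before Keras is loaded.
    import keras

    # Assumption: with no arguments, DataParallel uses the devices that
    # keras.distribution.list_devices() reports.
    distribution = keras.distribution.DataParallel()
    keras.distribution.set_distribution(distribution)

    # distribution() returns whatever set_distribution() stored last.
    return keras.distribution.distribution()
```

Calling enable_data_parallel() once at program start should return the same object that was passed to set_distribution.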
list_devices function
keras.distribution.list_devices(device_type=None)
Return all the available devices based on the device type.
Note: in a distributed setting, global devices are returned.
Arguments
device_type: One of "cpu", "gpu" or "tpu". Defaults to "gpu" or "tpu" if available when device_type is not provided. Otherwise, the "cpu" devices are returned.
Return: List of devices that are available for distributed computation.
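A short sketch of how the returned list might be inspected. It assumes device names are strings of the form "<type>:<index>" (for example "gpu:0"), which this section does not state and which may differ between backends. Both group_by_type and available_devices are hypothetical helper names.

```python
from collections import defaultdict


def group_by_type(device_names):
    """Group names such as "gpu:0" by the text before the first colon.

    Assumption: names follow the "<type>:<index>" pattern.
    """
    groups = defaultdict(list)
    for name in device_names:
        device_type, _, _ = name.partition(":")
        groups[device_type].append(name)
    return dict(groups)


def available_devices(device_type=None):
    """Thin wrapper around keras.distribution.list_devices."""
    # Imported here so the grouping helper can be used without a backend.
    import keras

    return keras.distribution.list_devices(device_type)


# Sample names only; real output depends on the machine and the backend.
print(group_by_type(["gpu:0", "gpu:1", "cpu:0"]))
```

On a real machine, group_by_type(available_devices()) would show how many devices of the default type are visible to the backend.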
initialize function
keras.distribution.initialize(
    job_addresses=None, num_processes=None, process_id=None
)
Initialize the distribution system for multi-host/process setting.
Calling initialize will prepare the backend for execution on multi-host GPUs or TPUs. It should be called before any computations.
Note that the parameters can also be injected via environment variables, which can be better controlled by the launch script at startup time. For backends that also rely on environment variables for their configuration, Keras will forward them properly.
Arguments
job_addresses: This value can be None, in which case the backend will figure it out from the TPU environment variables. You can also configure this value via the environment variable KERAS_DISTRIBUTION_JOB_ADDRESSES.
num_processes: This value can be None, in which case the backend will figure it out from the TPU environment variables. You can also configure this value via the environment variable KERAS_DISTRIBUTION_NUM_PROCESSES.
process_id: Ranges from 0 to num_processes - 1. 0 indicates that the current worker/process is the master/coordinator job. You can also configure this value via the environment variable KERAS_DISTRIBUTION_PROCESS_ID.
Suppose there are two GPU processes, with process 0 running at address 10.0.0.1:1234 and process 1 running at address 10.0.0.2:2345. To configure such a cluster, you can run:
On process 0:
keras.distribution.initialize(
    job_addresses="10.0.0.1:1234,10.0.0.2:2345",
    num_processes=2,
    process_id=0)
On process 1:
keras.distribution.initialize(
    job_addresses="10.0.0.1:1234,10.0.0.2:2345",
    num_processes=2,
    process_id=1)
Or via the environment variables:
On process 0:
import os
os.environ["KERAS_DISTRIBUTION_JOB_ADDRESSES"] = "10.0.0.1:1234,10.0.0.2:2345"
os.environ["KERAS_DISTRIBUTION_NUM_PROCESSES"] = "2"
os.environ["KERAS_DISTRIBUTION_PROCESS_ID"] = "0"
keras.distribution.initialize()
On process 1:
import os
os.environ["KERAS_DISTRIBUTION_JOB_ADDRESSES"] = "10.0.0.1:1234,10.0.0.2:2345"
os.environ["KERAS_DISTRIBUTION_NUM_PROCESSES"] = "2"
os.environ["KERAS_DISTRIBUTION_PROCESS_ID"] = "1"
keras.distribution.initialize()
Also note that for the JAX backend, job_addresses can be further reduced to just the master/coordinator address, which is 10.0.0.1:1234.
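Since the parameters can be injected through environment variables by a launch script, the per-process values can be derived from a single list of addresses. The following is a minimal sketch of such a script fragment. The helper name distribution_env is hypothetical and not part of Keras; only the three variable names come from this page.

```python
def distribution_env(addresses, process_id):
    """Build the KERAS_DISTRIBUTION_* variables for one process.

    Hypothetical launch-script helper; not part of Keras.
    `addresses` holds one "host:port" string per process, in process order.
    """
    num_processes = len(addresses)
    if not 0 <= process_id < num_processes:
        raise ValueError(
            f"process_id must be in [0, {num_processes - 1}], got {process_id}"
        )
    return {
        "KERAS_DISTRIBUTION_JOB_ADDRESSES": ",".join(addresses),
        "KERAS_DISTRIBUTION_NUM_PROCESSES": str(num_processes),
        "KERAS_DISTRIBUTION_PROCESS_ID": str(process_id),
    }


# The two-process cluster used in the example on this page.
addresses = ["10.0.0.1:1234", "10.0.0.2:2345"]
for process_id in range(len(addresses)):
    print(process_id, distribution_env(addresses, process_id))
```

A launcher could merge the returned dictionary into os.environ for each process before that process calls keras.distribution.initialize() with no arguments.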