Authors: Tom O'Malley, Haifeng Jin
Date created: 2019/10/24
Last modified: 2021/06/02
Description: Tuning the hyperparameters of your models with multiple GPUs and multiple machines.
!pip install keras-tuner -q
KerasTuner makes it easy to perform distributed hyperparameter search. No changes to your code are needed to scale up from running single-threaded locally to running on dozens or hundreds of workers in parallel. Distributed KerasTuner uses a chief-worker model. The chief runs a service to which the workers report results and query for the hyperparameters to try next. The chief should be run on a single-threaded CPU instance (or alternatively as a separate process on one of the workers).
Configuring distributed mode for KerasTuner only requires setting three environment variables:
KERASTUNER_TUNER_ID: This should be set to "chief" for the chief process. Other workers should be passed a unique ID (by convention, "tuner0", "tuner1", etc).
KERASTUNER_ORACLE_IP: The IP address or hostname that the chief service should run on. All workers should be able to resolve and access this address.
KERASTUNER_ORACLE_PORT: The port that the chief service should run on. This can be freely chosen, but must be a port that is accessible to the other workers. Instances communicate via the gRPC protocol.
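If you launch processes from Python rather than a shell, the same configuration can be applied through os.environ before the tuner is constructed. This is just a sketch of that option; the values are placeholders equivalent to the bash exports shown below.

import os

# Equivalent to the bash exports below; set these before constructing
# the tuner so KerasTuner can pick them up from the environment.
os.environ["KERASTUNER_TUNER_ID"] = "tuner0"  # use "chief" on the chief process
os.environ["KERASTUNER_ORACLE_IP"] = "127.0.0.1"
os.environ["KERASTUNER_ORACLE_PORT"] = "8000"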
The same code can be run on all workers. Additional considerations for distributed mode are:
overwrite should be kept as False in Tuner.__init__ (False is the default).

Example bash script for chief service (sample code for run_tuning.py at bottom of page):
export KERASTUNER_TUNER_ID="chief"
export KERASTUNER_ORACLE_IP="127.0.0.1"
export KERASTUNER_ORACLE_PORT="8000"
python run_tuning.py
Example bash script for worker:
export KERASTUNER_TUNER_ID="tuner0"
export KERASTUNER_ORACLE_IP="127.0.0.1"
export KERASTUNER_ORACLE_PORT="8000"
python run_tuning.py
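For local experimentation, the chief and workers can also be started from a single Python launcher instead of separate shells. The sketch below is only illustrative (the process layout, the sleep, and the worker count are assumptions); it reuses the run_tuning.py script referenced above.

import os
import subprocess
import time

# Hypothetical local launcher: start the chief and two workers as separate
# processes, each with its own KERASTUNER_TUNER_ID.
base_env = {
    **os.environ,
    "KERASTUNER_ORACLE_IP": "127.0.0.1",
    "KERASTUNER_ORACLE_PORT": "8000",
}

chief = subprocess.Popen(
    ["python", "run_tuning.py"], env={**base_env, "KERASTUNER_TUNER_ID": "chief"}
)
time.sleep(5)  # give the chief's gRPC service a moment to start

workers = [
    subprocess.Popen(
        ["python", "run_tuning.py"],
        env={**base_env, "KERASTUNER_TUNER_ID": f"tuner{i}"},
    )
    for i in range(2)
]

for worker in workers:
    worker.wait()
chief.terminate()  # stop the chief service once the workers are done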
Data parallelism with tf.distribute
KerasTuner also supports data parallelism via tf.distribute. Data parallelism and distributed tuning can be combined. For example, if you have 10 workers with 4 GPUs on each worker, you can run 10 parallel trials with each trial training on 4 GPUs by using tf.distribute.MirroredStrategy. You can also run each trial on TPUs via tf.distribute.TPUStrategy. Currently tf.distribute.MultiWorkerMirroredStrategy is not supported, but support for this is on the roadmap.
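For instance, each trial could be placed on a TPU by passing a tf.distribute.TPUStrategy as the distribution_strategy. The snippet below is a rough sketch of the usual TensorFlow TPU setup; the resolver address and the project_name are assumptions that depend on your environment, and build_model is the same model-building function defined in the full example below.

import tensorflow as tf
import keras_tuner

# Connect to the TPU; the address passed to the resolver depends on your setup.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_strategy = tf.distribute.TPUStrategy(resolver)

# Each trial launched by this tuner trains on the TPU.
tuner = keras_tuner.Hyperband(
    hypermodel=build_model,
    objective="val_accuracy",
    max_epochs=2,
    distribution_strategy=tpu_strategy,
    directory="results_dir",
    project_name="mnist_tpu",
    overwrite=True,
)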
When the environment variables described above are set, the example below will run distributed tuning and use data parallelism within each trial via tf.distribute. The example loads MNIST from keras.datasets and uses Hyperband for the hyperparameter search.
import keras
import keras_tuner
import tensorflow as tf
import numpy as np
def build_model(hp):
    """Builds a convolutional model."""
    inputs = keras.Input(shape=(28, 28, 1))
    x = inputs
    for i in range(hp.Int("conv_layers", 1, 3, default=3)):
        x = keras.layers.Conv2D(
            filters=hp.Int("filters_" + str(i), 4, 32, step=4, default=8),
            kernel_size=hp.Int("kernel_size_" + str(i), 3, 5),
            activation="relu",
            padding="same",
        )(x)

        if hp.Choice("pooling" + str(i), ["max", "avg"]) == "max":
            x = keras.layers.MaxPooling2D()(x)
        else:
            x = keras.layers.AveragePooling2D()(x)

        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU()(x)

    if hp.Choice("global_pooling", ["max", "avg"]) == "max":
        x = keras.layers.GlobalMaxPooling2D()(x)
    else:
        x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(10, activation="softmax")(x)

    model = keras.Model(inputs, outputs)

    optimizer = hp.Choice("optimizer", ["adam", "sgd"])
    model.compile(
        optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"]
    )
    return model
tuner = keras_tuner.Hyperband(
    hypermodel=build_model,
    objective="val_accuracy",
    max_epochs=2,
    factor=3,
    hyperband_iterations=1,
    distribution_strategy=tf.distribute.MirroredStrategy(),
    directory="results_dir",
    project_name="mnist",
    overwrite=True,
)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Reshape the images to have the channel dimension.
x_train = (x_train.reshape(x_train.shape + (1,)) / 255.0)[:1000]
y_train = y_train.astype(np.int64)[:1000]
x_test = (x_test.reshape(x_test.shape + (1,)) / 255.0)[:100]
y_test = y_test.astype(np.int64)[:100]
tuner.search(
    x_train,
    y_train,
    steps_per_epoch=600,
    validation_data=(x_test, y_test),
    validation_steps=100,
    callbacks=[keras.callbacks.EarlyStopping("val_accuracy")],
)
Trial 2 Complete [00h 00m 18s]
val_accuracy: 0.07000000029802322
Best val_accuracy So Far: 0.07000000029802322
Total elapsed time: 00h 00m 26s
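Once the search has finished, the best hyperparameters and models can be retrieved through the standard KerasTuner API on any process that has access to the results directory. A brief sketch:

# Retrieve the best hyperparameters found during the search.
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(best_hps.values)

# Load the best model directly, or rebuild it from best_hps and retrain.
best_model = tuner.get_best_models(num_models=1)[0]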