Tuner class

```python
keras_tuner.Tuner(
    oracle,
    hypermodel=None,
    max_model_size=None,
    optimizer=None,
    loss=None,
    metrics=None,
    distribution_strategy=None,
    directory=None,
    project_name=None,
    logger=None,
    tuner_id=None,
    overwrite=False,
    executions_per_trial=1,
    **kwargs
)
```
Tuner class for Keras models.

This is the base Tuner class for all tuners for Keras models. It manages the building, training, evaluation and saving of the Keras models. New tuners can be created by subclassing the class.

All Keras-related logic lives in Tuner.run_trial() and its subroutines. When subclassing Tuner, if you do not call super().run_trial(), the subclass can tune anything, not only Keras models.
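Arguments

- oracle: Instance of Oracle class.
- hypermodel: Instance of HyperModel class (or callable that takes hyperparameters and returns a Model instance). It is optional when Tuner.run_trial() is overridden and does not use self.hypermodel.
- max_model_size: Integer, maximum number of scalars in the parameters of a model. Models larger than this are rejected.
- optimizer: Optional optimizer. It is used to override the optimizer argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
- loss: Optional loss. May be used to override the loss argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
- metrics: Optional metrics. May be used to override the metrics argument in the compile step for the models. If the hypermodel does not compile the models it generates, then this argument must be specified.
- distribution_strategy: Optional instance of tf.distribute.Strategy. If specified, each trial will run under this scope. For example, tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) will run each trial on two GPUs. Currently only single-worker strategies are supported.
- directory: A string, the relative path to the working directory.
- project_name: A string, the name to use as prefix for files saved by this Tuner.
- tuner_id: Optional string, used as the ID of this Tuner.
- overwrite: Boolean, defaults to False. If False, reloads an existing project of the same name if one is found. Otherwise, overwrites the project.
- executions_per_trial: Integer, the number of executions (training a model from scratch, starting from a new initialization) to run per trial (model configuration). Because model metrics can vary greatly with random initialization, running several executions per trial gives a more reliable evaluation of a given set of hyperparameters.
- **kwargs: Arguments for BaseTuner.

Attributes

- remaining_trials: Number of trials remaining, None if max_trials is not set. This is useful when resuming a previously stopped search.

get_best_hyperparameters method

Tuner.get_best_hyperparameters(num_trials=1)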
Returns the best hyperparameters, as determined by the objective.
This method can be used to reinstantiate the (untrained) best model found during the search process.
Example
```python
best_hp = tuner.get_best_hyperparameters()[0]
model = tuner.hypermodel.build(best_hp)
```
Arguments

- num_trials: Optional number of HyperParameters objects to return.

Returns

List of HyperParameters objects sorted from the best to the worst.
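As a rough picture of what "sorted from the best to the worst" means, the plain-Python sketch below ranks finished trials by their objective score. TrialSummary and rank_trials are hypothetical names, not part of keras_tuner; the sketch assumes each trial has already been reduced to a single score and that the objective direction is either "min" (for example a loss) or "max" (for example an accuracy).

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class TrialSummary:
    """Hypothetical stand-in for a finished trial: its hyperparameters and objective score."""
    hyperparameters: Dict[str, float]
    score: float


def rank_trials(trials: List[TrialSummary], direction: str = "min", num_trials: int = 1) -> List[Dict[str, float]]:
    """Return the hyperparameters of the top `num_trials` trials, best first."""
    if direction not in ("min", "max"):
        raise ValueError("direction must be 'min' or 'max'")
    # For a "max" objective the highest score is best; for a "min" objective the lowest is.
    ordered = sorted(trials, key=lambda t: t.score, reverse=(direction == "max"))
    return [t.hyperparameters for t in ordered[:num_trials]]


trials = [
    TrialSummary({"units": 32}, score=0.41),
    TrialSummary({"units": 64}, score=0.35),
    TrialSummary({"units": 128}, score=0.38),
]
print(rank_trials(trials, direction="min", num_trials=2))  # [{'units': 64}, {'units': 128}]
```

With a "min" objective the trial scoring 0.35 comes first; switching direction to "max" would put the 0.41 trial first instead.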
get_best_models method

Tuner.get_best_models(num_models=1)
Returns the best model(s), as determined by the tuner's objective.
The models are loaded with the weights corresponding to their best checkpoint (at the end of the best epoch of the best trial).

This method is for querying the models trained during the search. For best performance, it is recommended to retrain your Model on the full dataset using the best hyperparameters found during search, which can be obtained using tuner.get_best_hyperparameters().
Arguments

- num_models: Optional number of best models to return. Defaults to 1.

Returns
List of trained model instances sorted from the best to the worst.
get_state method

Tuner.get_state()

Returns the current state of this object.

This method is called during save.
Returns
A dictionary of serializable objects as the state.
load_model method

Tuner.load_model(trial)

Loads a Model from a given trial.

For models that report intermediate results to the Oracle, load_model should generally load the best reported step by relying on trial.best_step.
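A minimal sketch of the idea behind best_step, assuming one checkpoint is saved per step (epoch) and the objective is a validation loss to minimize. CheckpointStore, its save_model and load_best methods, and the in-memory dictionary keyed by (trial_id, step) are all hypothetical; the dictionary simply stands in for whatever storage a real tuner uses.

```python
from typing import Dict, List, Tuple


class CheckpointStore:
    """Hypothetical in-memory stand-in for per-step model checkpoints."""

    def __init__(self) -> None:
        self._weights: Dict[Tuple[str, int], List[float]] = {}

    def save_model(self, trial_id: str, weights: List[float], step: int = 0) -> None:
        # Mirrors the save_model(trial_id, model, step) idea: one snapshot per step.
        self._weights[(trial_id, step)] = list(weights)

    def load_best(self, trial_id: str, val_losses: List[float]) -> Tuple[int, List[float]]:
        # The best step is the one with the lowest reported objective (loss is minimized here).
        best_step = min(range(len(val_losses)), key=lambda step: val_losses[step])
        return best_step, self._weights[(trial_id, best_step)]


store = CheckpointStore()
val_losses = [0.90, 0.55, 0.48, 0.52]  # objective reported at the end of epochs 0..3
for epoch in range(len(val_losses)):
    # Pretend the "weights" are just the epoch number so the loaded snapshot is easy to identify.
    store.save_model("trial_0", [float(epoch)], step=epoch)

best_step, weights = store.load_best("trial_0", val_losses)
print(best_step, weights)  # 2 [2.0]
```

Loading the epoch-2 snapshot rather than the last one (epoch 3) is the point: the final epoch is not necessarily the best.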
Arguments

- trial: A Trial instance, the Trial corresponding to the model to load.

on_epoch_begin method

Tuner.on_epoch_begin(trial, model, epoch, logs=None)
Called at the beginning of an epoch.
Arguments

- trial: A Trial instance.
- model: A Keras Model.
- epoch: The current epoch number.
- logs: Additional metrics.

on_batch_begin method

Tuner.on_batch_begin(trial, model, batch, logs)
Called at the beginning of a batch.
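Arguments

- trial: A Trial instance.
- model: A Keras Model.
- batch: The current batch number within the current epoch.
- logs: Additional metrics.

on_batch_end method

Tuner.on_batch_end(trial, model, batch, logs=None)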
Called at the end of a batch.
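Arguments

- trial: A Trial instance.
- model: A Keras Model.
- batch: The current batch number within the current epoch.
- logs: Additional metrics.

on_epoch_end method

Tuner.on_epoch_end(trial, model, epoch, logs=None)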
Called at the end of an epoch.
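The four hooks follow the usual Keras callback order: on_epoch_begin, then on_batch_begin and on_batch_end for each batch, then on_epoch_end, whose logs are expected to include the objective for that epoch. The plain-Python loop below only records that order. HookRecorder and fake_fit are hypothetical, "val_loss" is an assumed objective name, and a real subclass would override these methods on Tuner rather than on a standalone class.

```python
from typing import Dict, List, Optional


class HookRecorder:
    """Hypothetical object exposing the same hook names as Tuner; it just logs each call."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_epoch_begin(self, trial, model, epoch: int, logs: Optional[Dict] = None) -> None:
        self.calls.append(f"epoch_begin {epoch}")

    def on_batch_begin(self, trial, model, batch: int, logs: Dict) -> None:
        self.calls.append(f"batch_begin {batch}")

    def on_batch_end(self, trial, model, batch: int, logs: Optional[Dict] = None) -> None:
        self.calls.append(f"batch_end {batch}")

    def on_epoch_end(self, trial, model, epoch: int, logs: Optional[Dict] = None) -> None:
        # The epoch logs carry the objective value (assumed here to be "val_loss").
        self.calls.append(f"epoch_end {epoch} val_loss={logs['val_loss']}")


def fake_fit(hooks: HookRecorder, epochs: int, batches: int) -> None:
    """Drive the hooks in Keras callback order without any real training."""
    for epoch in range(epochs):
        hooks.on_epoch_begin(None, None, epoch)
        for batch in range(batches):
            hooks.on_batch_begin(None, None, batch, logs={})
            hooks.on_batch_end(None, None, batch, logs={})
        hooks.on_epoch_end(None, None, epoch, logs={"val_loss": 1.0 / (epoch + 1)})


recorder = HookRecorder()
fake_fit(recorder, epochs=1, batches=2)
print(recorder.calls)
# ['epoch_begin 0', 'batch_begin 0', 'batch_end 0', 'batch_begin 1', 'batch_end 1', 'epoch_end 0 val_loss=1.0']
```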
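Arguments

- trial: A Trial instance.
- model: A Keras Model.
- epoch: The current epoch number.
- logs: Dict. Metrics for this epoch. This should include the value of the objective for this epoch.

run_trial method

Tuner.run_trial(trial, *args, **kwargs)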
Evaluates a set of hyperparameter values.
This method is called multiple times during search to build and evaluate the models with different hyperparameters and return the objective value.

Example

You can use it with self.hypermodel to build and fit the model.
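```python
def run_trial(self, trial, *args, **kwargs):
    hp = trial.hyperparameters
    # Build a fresh model from this trial's hyperparameter values.
    model = self.hypermodel.build(hp)
    # Train it with the arguments forwarded from search() and return the result.
    return self.hypermodel.fit(hp, model, *args, **kwargs)
```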
You can also use it as a black-box optimizer for anything.
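```python
def run_trial(self, trial, *args, **kwargs):
    hp = trial.hyperparameters
    # Define and sample a float hyperparameter in [-2.0, 2.0].
    x = hp.Float("x", -2.0, 2.0)
    # Any scalar function of the hyperparameters can serve as the objective.
    y = x * x + 2 * x + 1
    return y
```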
Arguments

- trial: A Trial instance that contains the information needed to run this trial. Hyperparameters can be accessed via trial.hyperparameters.
- *args: Positional arguments passed by search.
- **kwargs: Keyword arguments passed by search.

Returns
A History object (the return value of model.fit()), a dictionary, a float, or a list of one of these types.

If a dictionary is returned, it should map the names of the metrics to track to their values. The keys must include the objective name.

If a float is returned, it should be the objective value.

If the model is evaluated multiple times, you may return a list of results of any of the types above. The final objective value is the average of the results in the list.
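A plain-Python sketch of these return-value rules, assuming the objective is named "val_loss". The aggregate_objective helper is hypothetical and deliberately leaves out the History case; it only shows how a float, a metrics dictionary, or a list of either reduces to a single objective value, with lists averaged.

```python
from numbers import Real
from statistics import mean
from typing import Dict, List, Union

Result = Union[float, Dict[str, float]]


def aggregate_objective(results: Union[Result, List[Result]], objective: str = "val_loss") -> float:
    """Hypothetical helper: reduce a run_trial-style return value to one objective value."""
    if isinstance(results, list):
        # Several evaluations: average the objective of each one.
        return mean(aggregate_objective(r, objective) for r in results)
    if isinstance(results, dict):
        # A metrics dictionary must contain the objective name as a key.
        return float(results[objective])
    if isinstance(results, Real):
        # A bare number is taken to be the objective value itself.
        return float(results)
    raise TypeError(f"unsupported result type: {type(results).__name__}")


print(aggregate_objective(0.25))                                              # 0.25
print(aggregate_objective({"val_loss": 0.125, "val_accuracy": 0.8}))          # 0.125
print(aggregate_objective([{"val_loss": 0.5}, {"val_loss": 0.25}, 0.75]))     # 0.5
```

The list case is recursive, so one list may mix dictionaries and floats, which matches the "list of one of these types" wording loosely; a stricter variant could reject mixed lists.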
results_summary method

Tuner.results_summary(num_trials=10)

Display tuning results summary.

The method prints a summary of the search results, including the hyperparameter values and evaluation results for each trial.

Arguments

- num_trials: Optional number of trials to display. Defaults to 10.
save_model method

Tuner.save_model(trial_id, model, step=0)

Saves a Model for a given trial.

Arguments

- trial_id: The ID of the Trial corresponding to this Model.
- model: The trained model.
- step: Integer, for models that report intermediate results to the Oracle, the step the saved file corresponds to. For example, for Keras models this is the number of epochs trained.

search method

Tuner.search(*fit_args, **fit_kwargs)

Performs a search for the best hyperparameter configurations.
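To illustrate how search hands its arguments to run_trial, here is a toy random-search loop in plain Python. MiniTuner, its uniform sampling, and the seeded random.Random are assumptions made for the sketch: a real Tuner delegates the choice of hyperparameter values to its Oracle and passes a Trial, not a plain dictionary, to run_trial. The objective reuses the black-box example from run_trial, shifted by an offset that travels through search.

```python
import random
from typing import Any, Dict, Optional, Tuple


class MiniTuner:
    """Hypothetical toy tuner: random search that minimizes whatever run_trial returns."""

    def __init__(self, max_trials: int = 200, seed: int = 0) -> None:
        self.max_trials = max_trials
        self.rng = random.Random(seed)
        self.best: Optional[Tuple[Dict[str, float], float]] = None

    def run_trial(self, hp: Dict[str, float], offset: float = 0.0) -> float:
        # Same black-box objective as the run_trial example, plus a value forwarded from search().
        x = hp["x"]
        return x * x + 2 * x + 1 + offset

    def search(self, *fit_args: Any, **fit_kwargs: Any) -> None:
        for _ in range(self.max_trials):
            hp = {"x": self.rng.uniform(-2.0, 2.0)}  # stand-in for an Oracle proposing values
            # Everything passed to search() is forwarded unchanged to run_trial().
            score = self.run_trial(hp, *fit_args, **fit_kwargs)
            if self.best is None or score < self.best[1]:
                self.best = (hp, score)


tuner = MiniTuner(max_trials=500)
tuner.search(offset=10.0)
best_hp, best_score = tuner.best
# (x + 1)**2 + 10 is smallest at x = -1, so best_hp["x"] should land near -1.
print(round(best_hp["x"], 2), round(best_score, 3))
```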
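Arguments

- *fit_args: Positional arguments that should be passed to run_trial, for example the training and validation data.
- **fit_kwargs: Keyword arguments that should be passed to run_trial, for example the training and validation data.

search_space_summary method

Tuner.search_space_summary(extended=False)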
Print search space summary.
The method prints a summary of the hyperparameters in the search space. It can be called before calling the search method.

Arguments

- extended: Optional boolean, whether to display an extended summary. Defaults to False.
set_state method

Tuner.set_state(state)

Sets the current state of this object.

This method is called during reload.

Arguments

- state: A dictionary of serialized objects as the state to restore.
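get_state and set_state work as a save/reload pair: the dictionary returned by the first should be serializable and should hold everything the second needs to restore the object. A minimal sketch of that pattern, assuming JSON as the serialization format; SearchProgress and its attributes are hypothetical, not keras_tuner types.

```python
import json
from typing import Any, Dict, List


class SearchProgress:
    """Hypothetical object following the get_state/set_state pattern."""

    def __init__(self) -> None:
        self.completed_trials: List[str] = []
        self.best_score: float = float("inf")

    def get_state(self) -> Dict[str, Any]:
        # Only plain, JSON-serializable values go into the state dictionary.
        return {"completed_trials": list(self.completed_trials), "best_score": self.best_score}

    def set_state(self, state: Dict[str, Any]) -> None:
        # Restore every attribute that get_state() captured.
        self.completed_trials = list(state["completed_trials"])
        self.best_score = state["best_score"]


original = SearchProgress()
original.completed_trials = ["trial_0", "trial_1"]
original.best_score = 0.42

saved = json.dumps(original.get_state())   # "save": serialize the state
restored = SearchProgress()
restored.set_state(json.loads(saved))      # "reload": rebuild from the serialized state
print(restored.completed_trials, restored.best_score)  # ['trial_0', 'trial_1'] 0.42
```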