Better and faster hyperparameter optimization with Dask
Scott Sievert wrote this post. The original post lives at https://stsievert.com/blog/2019/09/27/dask-hyperparam-opt/ with better styling. This work is supported by Anaconda, Inc.
Dask’s machine learning package, Dask-ML, now implements Hyperband, an advanced “hyperparameter optimization” algorithm that performs rather well. This post will
- describe “hyperparameter optimization”, a common problem in machine learning
- describe Hyperband’s benefits and why it works
- show how to use Hyperband via example alongside performance comparisons
In this post, I’ll walk through a practical example and highlight key portions of the paper “Better and faster hyperparameter optimization with Dask”, which is also summarized in a ~25 minute SciPy 2019 talk.
Problem
Machine learning requires data, an untrained model and “hyperparameters”: parameters that are chosen before training begins and that help with cohesion between the model and data. The user needs to specify values for these hyperparameters in order to use the model. A good example is adapting ridge regression or LASSO to the amount of noise in the data with the regularization parameter.[1]
Model performance strongly depends on the hyperparameters provided. A fairly complex example is with a particular visualization tool, t-SNE. This tool requires (at least) three hyperparameters and performance depends radically on the hyperparameters. In fact, the first section in “How to Use t-SNE Effectively” is titled “Those hyperparameters really matter”.
Finding good values for these hyperparameters is critical and has an entire Scikit-learn documentation page, “Tuning the hyperparameters of an estimator.” Briefly, finding decent values of hyperparameters is difficult and requires guessing or searching.
How can these hyperparameters be found quickly and efficiently with an advanced task scheduler like Dask? Parallelism will pose some challenges, but the Dask architecture enables some advanced algorithms.
Note: this post presumes knowledge of Dask basics. This material is covered in Dask’s documentation on Why Dask?, a ~15 minute video introduction to Dask, a video introduction to Dask-ML and a blog post I wrote on my first use of Dask.
Contributions
Dask-ML can quickly find high-performing hyperparameters. I will back this claim with intuition and experimental evidence.
Specifically, this is because Dask-ML now implements an algorithm introduced by Li et al. in “Hyperband: A novel bandit-based approach to hyperparameter optimization”. Pairing Dask and Hyperband enables some exciting new performance opportunities, especially because Hyperband has a simple implementation and Dask is an advanced task scheduler.[2]
Let’s go through the basics of Hyperband then illustrate its use and performance with an example. This will highlight some key points of the corresponding paper.
Hyperband basics
The motivation for Hyperband is to find high performing hyperparameters with minimal training. Given this goal, it makes sense to spend more time training high performing models – why waste more training time on a model that has performed poorly in the past?
One method to spend more time on high performing models is to initialize many models, start training all of them, and then stop training low performing models before training is finished. That’s what Hyperband does. At the most basic level, Hyperband is a (principled) early-stopping scheme for RandomizedSearchCV.
Deciding when to stop the training of models depends on how strongly the training data affect the score. There are two extremes:
- when only the training data matter
- i.e., when the hyperparameters don’t influence the score at all
- when only the hyperparameters matter
- i.e., when the training data don’t influence the score at all
Hyperband balances these two extremes by sweeping over how frequently models are stopped. This sweep allows a mathematical proof that Hyperband will find the best model possible with minimal partial_fit calls.[3]
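To make that concrete, here’s a rough sketch of the successive-halving loop that Hyperband repeats with different stopping rates. This is only illustrative: the names, the keep_frac fraction and the tripling schedule are assumptions, not Dask-ML’s actual code, and it glosses over details (for instance, the first partial_fit call for Scikit-learn classifiers also needs the class labels).

import numpy as np

def successive_halving(models, X, y, n_calls=1, keep_frac=1 / 3):
    # One "bracket", roughly: train every model a little, then keep only
    # the best-scoring fraction and let the survivors train longer.
    models = list(models)
    while len(models) > 1:
        for model in models:
            for _ in range(n_calls):
                model.partial_fit(X, y)
        scores = [model.score(X, y) for model in models]
        keep = max(1, int(len(models) * keep_frac))
        best = np.argsort(scores)[-keep:]   # indices of the top performers
        models = [models[i] for i in best]
        n_calls *= 3                        # survivors train longer next round
    return models[0]

# Hyperband sweeps over how aggressively this loop culls models,
# running several such brackets.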
Hyperband has significant parallelism because it has two “embarrassingly parallel” for-loops – Dask can exploit this. Hyperband has been implemented in Dask, specifically in Dask’s machine learning library Dask-ML.
How well does it perform? Let’s illustrate via example. Some setup is required before the performance comparison in Performance.
Example
Note: want to try HyperbandSearchCV out yourself? Dask has an example use. It can even be run in-browser!
I’ll illustrate with a synthetic example. Let’s build a dataset with 4 classes:
>>> from experiment import make_circles
>>> X, y = make_circles(n_classes=4, n_features=6, n_informative=2)
>>> scatter(X[:, :2], color=y)
Note: this content is pulled from stsievert/dask-hyperband-comparison, with slight modifications.
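If you don’t have the experiment module from that repository handy, a rough stand-in dataset can be built with Scikit-learn and wrapped as Dask arrays. The sizes and chunking below are placeholders, not the values used in the original experiments.

import dask.array as da
from sklearn.datasets import make_classification

# A 4-class, 6-feature synthetic dataset, stored as Dask arrays so that
# the chunk size can later control how much data each partial_fit call sees.
X, y = make_classification(
    n_samples=60_000, n_features=6, n_informative=2,
    n_classes=4, n_clusters_per_class=1, random_state=0,
)
X = da.from_array(X, chunks=(5_000, 6))
y = da.from_array(y, chunks=5_000)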
Let’s build a fully connected neural net with 24 neurons for classification:
>>> from sklearn.neural_network import MLPClassifier
>>> model = MLPClassifier()
Building the neural net with PyTorch is also possible[4] (and is what I used in development).
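For reference, the PyTorch route looks roughly like the sketch below with skorch. The module, layer sizes and settings here are illustrative, not the exact network used in the experiments.

import torch.nn as nn
from skorch import NeuralNetClassifier

class MLP(nn.Module):
    # A small fully connected network: 6 inputs, one hidden layer, 4 classes.
    def __init__(self, hidden=24):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(6, hidden), nn.ReLU(),
            nn.Linear(hidden, 4), nn.LogSoftmax(dim=-1),
        )

    def forward(self, x):
        return self.net(x)

# skorch models implement partial_fit, which is what HyperbandSearchCV needs.
model = NeuralNetClassifier(MLP, max_epochs=1, train_split=None, verbose=0)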
This neural net’s behavior is dictated by 7 hyperparameters. Only one of them, hidden_layer_sizes (the number of neurons in each layer), controls the architecture of the optimal model. The rest control finding the best model of that architecture. Details on the hyperparameters are in the Appendix.
>>> params = ... # details in appendix
>>> params.keys()
dict_keys(['hidden_layer_sizes', 'alpha', 'batch_size', 'learning_rate',
'learning_rate_init', 'power_t', 'momentum'])
>>> params["hidden_layer_sizes"] # always 24 neurons
[(24, ), (12, 12), (6, 6, 6, 6), (4, 4, 4, 4, 4, 4), (12, 6, 3, 3)]
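For concreteness, the search space might look something like the dictionary below. The exact distributions live in the repository linked above, so treat these ranges as placeholders.

from scipy.stats import loguniform, uniform

params = {
    # architecture: always 24 neurons, arranged in different ways
    "hidden_layer_sizes": [
        (24,), (12, 12), (6, 6, 6, 6), (4, 4, 4, 4, 4, 4), (12, 6, 3, 3)
    ],
    # regularization and optimization hyperparameters
    "alpha": loguniform(1e-6, 1e-3),
    "batch_size": [32, 64, 128, 256],
    "learning_rate": ["constant", "invscaling", "adaptive"],
    "learning_rate_init": loguniform(1e-4, 1e-1),
    "power_t": uniform(0.1, 0.8),
    "momentum": uniform(0, 1),
}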
I chose these hyperparameters to create a complex search space that mimics the searches performed for most neural networks. These searches typically involve hyperparameters like “dropout”, “learning rate”, “momentum” and “weight decay”.[5] End users don’t care about hyperparameters like these; they don’t change the model architecture, they only influence finding the best model of a particular architecture.
How can high performing hyperparameter values be found quickly?
Finding the best parameters
First, let’s look at the parameters required for Dask-ML’s implementation of Hyperband (which is in the class HyperbandSearchCV).
Hyperband parameters: rule-of-thumb
HyperbandSearchCV has two inputs:
- max_iter, which determines how many times to call partial_fit
- the chunk size of the Dask array, which determines how much data each partial_fit call receives
These fall out pretty naturally once it’s known how long to train the best model and very approximately how many parameters to sample:
n_examples = 50 * len(X_train) # 50 passes through dataset for best model
n_params = 299 # sample about 300 parameters
# inputs to hyperband
max_iter = n_params
chunk_size = n_examples // n_params
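For example, if X_train held 60,000 examples (a number picked only for illustration), this works out to n_examples = 3,000,000, max_iter = 299 and chunk_size = 3,000,000 // 299 ≈ 10,033, so each partial_fit call would see roughly 10,000 examples.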
The inputs to this rule-of-thumb are exactly what the user cares about:
- a measure of how complex the search space is (via n_params)
- how long to train the best model (via n_examples)
Notably, there’s no tradeoff between n_examples and n_params like with Scikit-learn’s RandomizedSearchCV because n_examples is only for some models, not for all models. There are more details on this rule-of-thumb in the “Notes” section of the HyperbandSearchCV docs.
With these inputs, a HyperbandSearchCV object can easily be created.
Finding the best performing hyperparameters
This model selection algorithm, Hyperband, is implemented in the class HyperbandSearchCV. Let’s create an instance of that class:
>>> from dask_ml.model_selection import HyperbandSearchCV
>>>
>>> search = HyperbandSearchCV(
...     model, params, max_iter=max_iter, aggressiveness=4
... )
aggressiveness defaults to 3. aggressiveness=4 is chosen because this is an initial search; I know nothing about this search space, so the search should be more aggressive in culling off bad models.
Hyperband hides some details from the user (which enables the mathematical
guarantees), specifically the details on the amount of training and
the number of models created. These details are available in the metadata
attribute:
>>> search.metadata["n_models"]
378
>>> search.metadata["partial_fit_calls"]
5721
Now that we have some idea on how long the computation will take, let’s ask it to find the best set of hyperparameters:
>>> from dask_ml.model_selection import train_test_split
>>> X_train, X_test, y_train, y_test = train_test_split(X, y)
>>>
>>> X_train = X_train.rechunk(chunk_size)
>>> y_train = y_train.rechunk(chunk_size)
>>>
>>> search.fit(X_train, y_train)
The dashboard will be active during this time.[6]
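The dashboard comes from Dask’s distributed scheduler; a minimal local setup, created before calling fit, looks something like the snippet below (4 workers here only to match the laptop used later in this post).

from dask.distributed import Client

# Start a local cluster; the dashboard is linked from the client's repr
# (by default at http://localhost:8787).
client = Client(n_workers=4)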
How well do these hyperparameters perform?
>>> search.best_score_
0.9019221418447483
HyperbandSearchCV mirrors Scikit-learn’s API for RandomizedSearchCV, so it has access to all the expected attributes and methods:
>>> search.best_params_
{"batch_size": 64, "hidden_layer_sizes": [6, 6, 6, 6], ...}
>>> search.score(X_test, y_test)
0.8989070100111217
>>> search.best_estimator_
MLPClassifier(...)
Details on the attributes and methods are in the HyperbandSearchCV documentation.
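For example, the history of every sampled hyperparameter combination can be inspected through cv_results_, just as with RandomizedSearchCV. This is a quick sketch and assumes Pandas is installed.

import pandas as pd

# One row per sampled hyperparameter combination, following dask-ml's
# cv_results_ convention.
results = pd.DataFrame(search.cv_results_)
print(results.head())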
Performance
I ran this 200 times on my personal laptop with 4 cores. Let’s look at the distribution of final validation scores:
The “passive” comparison is really RandomizedSearchCV configured so that it does the same amount of work as HyperbandSearchCV. Let’s see how this does over time:
This graph shows the mean score over the 200 runs with the solid line, and the shaded region represents the interquartile range. The dotted green line indicates the data required to train 4 models to completion. “Passes through the dataset” is a good proxy for “time to solution” because there are only 4 workers.
This graph shows that HyperbandSearchCV will find parameters at least 3 times quicker than RandomizedSearchCV.
Dask opportunities
What opportunities does combining Hyperband and Dask create? HyperbandSearchCV has a lot of internal parallelism, and Dask is an advanced task scheduler.
The most obvious opportunity involves job prioritization. Hyperband fits many models in parallel, and Dask might not have that many workers available. This means some jobs have to wait for other jobs to finish. Of course, Dask can prioritize jobs[7] and choose which models to fit first.
Let’s assign the priority for fitting a certain model to be the model’s most recent score. How does this prioritization scheme influence the score? Let’s compare the prioritization schemes in a single run of the 200 above:
These two lines are the same in every way except for the prioritization scheme. This graph compares the “high scores” prioritization scheme with Dask’s default prioritization scheme (“fifo”).
This graph is certainly helped by the fact that it is run with only 4 workers. Job priority does not matter if every job can be run right away (there’s nothing to assign priority to!).
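For reference, Dask exposes this kind of control through per-task priorities. The sketch below is not HyperbandSearchCV’s internal code; fit_one_chunk and the scores are made up to show the mechanism.

from dask.distributed import Client

client = Client(n_workers=4)  # or reuse an existing client

def fit_one_chunk(model_id):
    # Placeholder for "call partial_fit on one model with one chunk of data".
    return model_id

# Higher priority numbers run first, so models with better recent scores
# get scheduled ahead of the rest.
recent_scores = {0: 0.61, 1: 0.48, 2: 0.73}
futures = [
    client.submit(fit_one_chunk, model_id, priority=score)
    for model_id, score in recent_scores.items()
]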
Amenability to parallelism
How does Hyperband scale with the number of workers?
I ran another separate experiment to measure this. This experiment is described more in the corresponding paper, but the relevant difference is that a PyTorch neural network is used through skorch instead of Scikit-learn’s MLPClassifier.
I ran the same experiment with a different number of Dask workers.[8] Here’s how HyperbandSearchCV scales:
Training one model to completion requires 243 seconds (which is marked by the white line). This is a comparison with patience, which stops training models if their scores aren’t increasing enough. Functionally, this is very useful because the user might accidentally specify n_examples to be too large.
It looks like the speedups start to saturate somewhere between 16 and 24 workers, at least for this example. Of course, patience doesn’t work as well for a large number of workers.[9]
Future work
There are some ongoing pull requests to improve HyperbandSearchCV. The most significant of these involves tweaking some Hyperband internals so HyperbandSearchCV works better with initial or very exploratory searches (dask/dask-ml #532).
The biggest improvement I see is treating dataset size as the scarce resource that needs to be preserved instead of training time. This would allow Hyperband to work with any model, instead of only models that implement partial_fit.
Serialization is an important part of the distributed Hyperband implementation in HyperbandSearchCV. Scikit-learn and PyTorch can easily handle this because they support the Pickle protocol[10], but Keras/Tensorflow/MXNet present challenges. The use of HyperbandSearchCV could be increased by resolving this issue.
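As a quick illustration of why Scikit-learn (and PyTorch via skorch) is the easy case here, its models round-trip through pickle with no extra work:

import pickle
from sklearn.neural_network import MLPClassifier

# Dask moves models between workers by serializing them; Scikit-learn
# estimators support this out of the box via the pickle protocol.
model = MLPClassifier()
restored = pickle.loads(pickle.dumps(model))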
Appendix
I chose to tune 7 hyperparameters, which are
- hidden_layer_sizes, which controls the number of neurons in each layer
- alpha, which controls the amount of regularization
More hyperparameters control finding the best neural network:
- batch_size, which controls the number of examples the optimizer uses to approximate the gradient
- learning_rate, learning_rate_init and power_t, which control some basic hyperparameters for the SGD optimizer I’ll be using
- momentum, a more advanced hyperparameter for SGD with Nesterov’s momentum
Footnotes
1. Which amounts to choosing alpha in Scikit-learn’s Ridge or LASSO.
2. To the best of my knowledge, this is the first implementation of Hyperband with an advanced task scheduler.
3. More accurately, Hyperband will find close to the best model possible with $N$ partial_fit calls in expected score with high probability, where “close” means “within log terms of the upper bound on score”. For details, see Corollary 1 of the corresponding paper or Theorem 5 of Hyperband’s paper.
5. There’s less tuning for adaptive step size methods like Adam or Adagrad, but they might under-perform on the test data (see “The Marginal Value of Adaptive Gradient Methods for Machine Learning”).
6. But it probably won’t be this fast: the video is sped up by a factor of 3.
7. See Dask’s documentation on Prioritizing Work.
8. Everything is the same between different runs: the hyperparameters sampled, the model’s internal random state, the data passed for fitting. Only the number of workers varies.
9. There’s no time benefit to stopping jobs early if there are infinite workers; there’s never a queue of jobs waiting to be run.
10. “Pickle isn’t slow, it’s a protocol” by Matthew Rocklin.