-
Notifications
You must be signed in to change notification settings - Fork 178
/
Copy pathdask_simple.py
49 lines (38 loc) · 1.65 KB
/
dask_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Optuna example that runs optimization trials in parallel using Dask.
In this example, we perform hyperparameter optimization on a
RandomForestClassifier which is trained using the handwritten digits
dataset. Trials are run in parallel on a Dask cluster using Optuna's
DaskStorage integration.
To run this example:
$ python dask_simple.py
"""
import optuna
from dask.distributed import Client
from dask.distributed import wait
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def objective(trial):
    """Fit a random forest on the digits dataset and return its test accuracy.

    Optuna tunes two hyperparameters: the maximum tree depth, sampled as an
    integer in [2, 10], and the number of trees, sampled as an integer in
    [1, 100].

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        Accuracy of the fitted classifier on the held-out split.
    """
    features, labels = load_digits(return_X_y=True)
    # A fixed random_state gives every trial the same train/test split,
    # so scores from different trials are comparable.
    x_fit, x_eval, y_fit, y_eval = train_test_split(features, labels, random_state=2)

    # Dict literals evaluate in order, so max_depth is still suggested first.
    params = {
        "max_depth": trial.suggest_int("max_depth", 2, 10),
        "n_estimators": trial.suggest_int("n_estimators", 1, 100),
    }
    model = RandomForestClassifier(**params)
    model.fit(x_fit, y_fit)

    return accuracy_score(y_eval, model.predict(x_eval))
if __name__ == "__main__":
    # 10 independent optimization tasks x 7 trials each = 70 trials in total.
    n_tasks = 10
    n_trials_per_task = 7

    with Client() as client:
        print(f"Dask dashboard is available at {client.dashboard_link}")
        # Shared study storage so that trials running concurrently on different
        # Dask workers all report into the same study.
        storage = optuna.integration.dask.DaskStorage()
        study = optuna.create_study(storage=storage, direction="maximize")
        # pure=False: every call has identical arguments, so without it Dask
        # would treat them as the same task and run the optimization only once.
        futures = [
            client.submit(
                study.optimize, objective, n_trials=n_trials_per_task, pure=False
            )
            for _ in range(n_tasks)
        ]
        wait(futures)
        # wait() returns normally even if a task raised. gather() re-raises the
        # first task error here, instead of letting it surface later as a
        # confusing "no trials completed" error from study.best_params.
        client.gather(futures)
        print(f"Best params: {study.best_params}")