Example: Pruning¶
This advanced example shows how to combine hyperparameter tuning with pruning, i.e., stopping unpromising trials early.
We use the breast cancer dataset from sklearn.datasets. This is a small and easy-to-train dataset whose goal is to predict whether or not a patient has breast cancer.
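Before diving in, it helps to see what pruning does conceptually. The HyperbandPruner used below is built on successive halving: train many configurations on a small budget, discard the worst performers, and grow the budget for the survivors. Here is a minimal, library-free sketch of that idea (the function names and the toy `score_fn` are illustrative, not ATOM or Optuna API):

```python
def successive_halving(configs, score_fn, max_iter=100, eta=3):
    """Toy successive halving: score every surviving config at the
    current budget, keep only the top 1/eta fraction, multiply the
    budget by eta, and repeat until one config remains."""
    budget = max(1, max_iter // eta**2)  # small initial budget
    survivors = list(configs)
    while len(survivors) > 1 and budget <= max_iter:
        # Rank configs by their score at this budget (higher is better)
        ranked = sorted(survivors, key=lambda c: score_fn(c, budget), reverse=True)
        # Prune everything outside the top 1/eta fraction
        survivors = ranked[: max(1, len(ranked) // eta)]
        budget *= eta
    return survivors[0]


# Example: with a score function that ignores the budget,
# the best config survives every round.
best = successive_halving(range(1, 10), lambda c, b: c)  # → 9
```

Pruned configurations only consume the small budgets of the early rounds, which is why the pruned trials in the output below finish in a fraction of the time of the completed ones.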
Load the data¶
In [1]:
# Import packages
from sklearn.datasets import load_breast_cancer
from optuna.pruners import HyperbandPruner
from atom import ATOMClassifier
In [2]:
# Load the data
X, y = load_breast_cancer(return_X_y=True)
Run the pipeline¶
In [3]:
# Initialize atom
atom = ATOMClassifier(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>

Algorithm task: binary classification.

Dataset stats ==================== >>
Shape: (569, 31)
Train set size: 456
Test set size: 113
-------------------------------------
Memory: 141.24 kB
Scaled: False
Outlier values: 167 (1.2%)
In [4]:
# Use ht_params to specify a custom pruner
# Note that pruned trials show the number of iterations they completed
atom.run(
models="SGD",
metric="f1",
n_trials=25,
ht_params={
"distributions": ["penalty", "max_iter"],
"pruner": HyperbandPruner(),
}
)
Training ========================= >>
Models: SGD
Metric: f1

Running hyperparameter tuning for StochasticGradientDescent...

| trial | penalty | max_iter | f1      | best_f1 | time_trial | time_ht | state    |
| ----- | ------- | -------- | ------- | ------- | ---------- | ------- | -------- |
| 0     | l1      | 650      | 0.9735  | 0.9735  | 2.369s     | 2.369s  | COMPLETE |
| 1     | elast.. | 1050     | 0.9739  | 0.9739  | 4.010s     | 6.380s  | COMPLETE |
| 2     | elast.. | 27/500   | 0.9825  | 0.9825  | 0.130s     | 6.510s  | PRUNED   |
| 3     | None    | 700      | 0.9825  | 0.9825  | 2.764s     | 9.273s  | COMPLETE |
| 4     | l1      | 3/1400   | 0.9821  | 0.9825  | 0.039s     | 9.312s  | PRUNED   |
| 5     | None    | 81/1400  | 0.9821  | 0.9825  | 0.318s     | 9.631s  | PRUNED   |
| 6     | l2      | 1200     | 0.9821  | 0.9825  | 4.592s     | 14.223s | COMPLETE |
| 7     | l2      | 1250     | 0.973   | 0.9825  | 4.491s     | 18.714s | COMPLETE |
| 8     | None    | 81/600   | 0.9636  | 0.9825  | 0.314s     | 19.028s | PRUNED   |
| 9     | l1      | 3/600    | 0.9655  | 0.9825  | 0.031s     | 19.059s | PRUNED   |
| 10    | l1      | 1000     | 0.9649  | 0.9825  | 3.803s     | 22.862s | COMPLETE |
| 11    | elast.. | 1/1200   | 0.955   | 0.9825  | 0.023s     | 22.885s | PRUNED   |
| 12    | l2      | 9/550    | 0.9636  | 0.9825  | 0.049s     | 22.934s | PRUNED   |
| 13    | elast.. | 1/1100   | 0.9636  | 0.9825  | 0.022s     | 22.956s | PRUNED   |
| 14    | None    | 3/900    | 0.9565  | 0.9825  | 0.029s     | 22.985s | PRUNED   |
| 15    | l1      | 1/1250   | 0.9735  | 0.9825  | 0.023s     | 23.008s | PRUNED   |
| 16    | None    | 3/750    | 0.9558  | 0.9825  | 0.028s     | 23.036s | PRUNED   |
| 17    | l2      | 1/1150   | 0.9558  | 0.9825  | 0.022s     | 23.058s | PRUNED   |
| 18    | l1      | 1/900    | 0.8727  | 0.9825  | 0.024s     | 23.082s | PRUNED   |
| 19    | l1      | 9/1100   | 0.9735  | 0.9825  | 0.052s     | 23.134s | PRUNED   |
| 20    | l2      | 3/1450   | 0.95    | 0.9825  | 0.029s     | 23.163s | PRUNED   |
| 21    | l2      | 1500     | 0.9828  | 0.9828  | 5.384s     | 28.548s | COMPLETE |
| 22    | l2      | 3/550    | 0.9735  | 0.9828  | 0.030s     | 28.578s | PRUNED   |
| 23    | l2      | 1/600    | 0.9655  | 0.9828  | 0.028s     | 28.606s | PRUNED   |
| 24    | elast.. | 9/1400   | 0.9558  | 0.9828  | 0.051s     | 28.657s | PRUNED   |

Hyperparameter tuning ---------------------------
Best trial --> 21
Best parameters:
 --> penalty: l2
 --> max_iter: 1500
Best evaluation --> f1: 0.9828
Time elapsed: 28.657s
Fit ---------------------------------------------
Train evaluation --> f1: 0.9948
Test evaluation --> f1: 0.9722
Time elapsed: 7.966s
-------------------------------------------------
Total time: 36.623s

Final results ==================== >>
Total time: 36.691s
-------------------------------------
StochasticGradientDescent --> f1: 0.9722
Analyze the results¶
In [5]:
atom.plot_trials()
In [6]:
atom.plot_hyperparameter_importance()