Example: Pruning¶
This example shows how to use pruning during hyperparameter tuning.

Import the breast cancer dataset from sklearn.datasets. This is a small and easy-to-train dataset whose goal is to predict whether or not a patient has breast cancer.
Load the data¶
In [1]:
# Import packages
from sklearn.datasets import load_breast_cancer
from optuna.pruners import HyperbandPruner
from atom import ATOMClassifier
In [2]:
# Load the data
X, y = load_breast_cancer(return_X_y=True)
Run the pipeline¶
In [3]:
# Initialize atom
atom = ATOMClassifier(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>

Configuration ==================== >>
Algorithm task: Binary classification.

Dataset stats ==================== >>
Shape: (569, 31)
Train set size: 456
Test set size: 113
-------------------------------------
Memory: 141.24 kB
Scaled: False
Outlier values: 167 (1.2%)
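The run below passes a custom pruner through ht_params. Optuna's HyperbandPruner applies successive halving over several resource brackets, stopping trials whose intermediate scores fall behind. As a minimal sketch, the pruner could also be configured explicitly; the values shown are Optuna's documented defaults, used here for illustration only, not as recommendations.

# Illustrative only: these are Optuna's default values for HyperbandPruner
pruner = HyperbandPruner(
    min_resource=1,        # minimum budget a trial gets before it can be pruned
    max_resource="auto",   # inferred from the highest step reported to a trial
    reduction_factor=3,    # only the top 1/3 of trials advance to the next rung
)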
In [4]:
# Use ht_params to specify a custom pruner
# Note that pruned trials show the number of iterations they completed
atom.run(
models="SGD",
metric="f1",
n_trials=25,
ht_params={
"distributions": ["penalty", "max_iter"],
"pruner": HyperbandPruner(),
}
)
Training ========================= >>
Models: SGD
Metric: f1

Running hyperparameter tuning for StochasticGradientDescent...
| trial | penalty | max_iter |     f1 | best_f1 | time_trial | time_ht |    state |
| ----- | ------- | -------- | ------ | ------- | ---------- | ------- | -------- |
|     0 |      l1 |      650 | 0.9558 |  0.9558 |     3.786s |  3.786s | COMPLETE |
|     1 | elast.. |     1050 | 0.9744 |  0.9744 |     8.431s | 12.217s | COMPLETE |
|     2 | elast.. |      500 | 0.9828 |  0.9828 |     0.040s | 12.257s |   PRUNED |
|     3 |    None |      700 | 0.9739 |  0.9828 |     3.873s | 16.130s | COMPLETE |
|     4 |      l1 |     1400 | 0.9735 |  0.9828 |     0.047s | 16.177s |   PRUNED |
|     5 |    None |     1400 | 0.9735 |  0.9828 |     7.163s | 23.340s | COMPLETE |
|     6 |      l2 |     1200 | 0.9825 |  0.9828 |     5.103s | 28.443s | COMPLETE |
|     7 |      l2 |     1250 | 0.9825 |  0.9828 |     5.729s | 34.172s | COMPLETE |
|     8 |    None |      600 | 0.9828 |  0.9828 |     0.034s | 34.206s |   PRUNED |
|     9 |      l1 |      600 | 0.9402 |  0.9828 |     0.042s | 34.249s |   PRUNED |
|    10 |      l2 |      950 | 0.9565 |  0.9828 |     4.432s | 38.681s | COMPLETE |
|    11 |      l2 |     1200 | 0.9825 |  0.9828 |     0.005s | 38.686s | COMPLETE |
|    12 |      l2 |     1200 | 0.9825 |  0.9828 |     0.004s | 38.690s | COMPLETE |
|    13 |      l2 |     1200 | 0.9825 |  0.9828 |     0.005s | 38.695s | COMPLETE |
|    14 |      l2 |     1500 | 0.9573 |  0.9828 |     0.050s | 38.745s |   PRUNED |
|    15 |      l2 |      950 | 0.9565 |  0.9828 |     0.005s | 38.750s | COMPLETE |
|    16 |      l2 |     1100 | 0.9391 |  0.9828 |     0.044s | 38.795s |   PRUNED |
|    17 |      l2 |      850 | 0.9831 |  0.9831 |     0.047s | 38.842s |   PRUNED |
|    18 | elast.. |     1300 |  0.931 |  0.9831 |     0.050s | 38.892s |   PRUNED |
|    19 |      l2 |     1300 | 0.9649 |  0.9831 |     0.072s | 38.963s |   PRUNED |
|    20 |      l2 |      800 | 0.9661 |  0.9831 |     0.044s | 39.007s |   PRUNED |
|    21 |      l2 |     1150 | 0.9402 |  0.9831 |     0.039s | 39.046s |   PRUNED |
|    22 |      l2 |     1300 | 0.9573 |  0.9831 |     0.050s | 39.096s |   PRUNED |
|    23 |      l2 |     1250 | 0.9825 |  0.9831 |     0.006s | 39.102s | COMPLETE |
|    24 |      l2 |     1050 | 0.9565 |  0.9831 |     0.096s | 39.198s |   PRUNED |
Hyperparameter tuning ---------------------------
Best trial --> 6
Best parameters:
 --> penalty: l2
 --> max_iter: 1200
Best evaluation --> f1: 0.9825
Time elapsed: 39.198s
Fit ---------------------------------------------
Train evaluation --> f1: 0.993
Test evaluation --> f1: 0.9722
Time elapsed: 10.046s
-------------------------------------------------
Time: 49.244s

Final results ==================== >>
Total time: 49.452s
-------------------------------------
StochasticGradientDescent --> f1: 0.9722
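Note how the pruned trials above finished in a fraction of the time of the completed ones. A minimal sketch to inspect them afterwards, assuming the fitted model exposes its tuning history through a trials DataFrame with a state column (attribute and column names may differ between ATOM versions):

# Assumed API: atom.sgd.trials holds one row per trial (params, scores, state)
trials = atom.sgd.trials
print(trials[trials["state"] == "PRUNED"])  # show only the pruned trials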
Analyze the results¶
In [5]:
atom.plot_trials()
In [7]:
atom.plot_timeline()
In [6]:
atom.plot_hyperparameter_importance()