Example: Successive halving
This example shows how to compare multiple tree-based models using successive halving.
It uses the California housing dataset from sklearn.datasets, a small, easy-to-train dataset whose goal is to predict house prices.
Load the data
In [1]:
from sklearn.datasets import fetch_california_housing
from atom import ATOMRegressor
In [2]:
# Load the data
X, y = fetch_california_housing(return_X_y=True)
Run the pipeline
In [3]:
atom = ATOMRegressor(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>
Algorithm task: regression.

Dataset stats ==================== >>
Shape: (20640, 9)
Memory: 1.49 MB
Scaled: False
Outlier values: 786 (0.5%)
-------------------------------------
Train set size: 16512
Test set size: 4128
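The train and test sizes reported above correspond to an 80/20 split of the 20640 rows. As a minimal sketch (using stand-in arrays of the same shape, so it runs without downloading the dataset), the same split can be reproduced with plain scikit-learn:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in arrays with the dataset's 20640 rows (8 features + target)
X = np.zeros((20640, 8))
y = np.zeros(20640)

# An 80/20 split yields the sizes ATOM reports above
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1
)
print(len(X_train), len(X_test))  # 16512 4128
```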
In [4]:
# Compare tree-based models via successive halving
atom.successive_halving(
models=["Tree", "Bag", "ET", "RF", "LGB", "CatB"],
metric="mae",
n_bootstrap=5,
)
Training ========================= >>
Metric: neg_mean_absolute_error

Run: 0 =========================== >>
Models: Tree6, Bag6, ET6, RF6, LGB6, CatB6
Size of training set: 16512 (17%)
Size of test set: 4128

Results for DecisionTree:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.5538
Time elapsed: 0.036s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.5542 ± 0.0284
Time elapsed: 0.321s
-------------------------------------------------
Total time: 0.357s

Results for Bagging:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1648
Test evaluation --> neg_mean_absolute_error: -0.42
Time elapsed: 0.146s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4289 ± 0.0055
Time elapsed: 0.831s
-------------------------------------------------
Total time: 0.977s

Results for ExtraTrees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3958
Time elapsed: 0.787s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4015 ± 0.0029
Time elapsed: 3.544s
-------------------------------------------------
Total time: 4.331s

Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.145
Test evaluation --> neg_mean_absolute_error: -0.4003
Time elapsed: 1.320s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4052 ± 0.005
Time elapsed: 6.386s
-------------------------------------------------
Total time: 7.707s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2022
Test evaluation --> neg_mean_absolute_error: -0.3599
Time elapsed: 0.361s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3607 ± 0.0018
Time elapsed: 0.848s
-------------------------------------------------
Total time: 1.209s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1591
Test evaluation --> neg_mean_absolute_error: -0.3433
Time elapsed: 4.811s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3486 ± 0.0017
Time elapsed: 21.200s
-------------------------------------------------
Total time: 26.010s

Final results ==================== >>
Total time: 40.595s
-------------------------------------
DecisionTree --> neg_mean_absolute_error: -0.5542 ± 0.0284 ~
Bagging --> neg_mean_absolute_error: -0.4289 ± 0.0055 ~
ExtraTrees --> neg_mean_absolute_error: -0.4015 ± 0.0029 ~
RandomForest --> neg_mean_absolute_error: -0.4052 ± 0.005 ~
LightGBM --> neg_mean_absolute_error: -0.3607 ± 0.0018 ~
CatBoost --> neg_mean_absolute_error: -0.3486 ± 0.0017 ~ !

Run: 1 =========================== >>
Models: ET3, LGB3, CatB3
Size of training set: 16512 (33%)
Size of test set: 4128

Results for ExtraTrees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3691
Time elapsed: 1.399s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3828 ± 0.0053
Time elapsed: 5.967s
-------------------------------------------------
Total time: 7.367s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2367
Test evaluation --> neg_mean_absolute_error: -0.3342
Time elapsed: 0.366s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3451 ± 0.0056
Time elapsed: 0.969s
-------------------------------------------------
Total time: 1.336s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1876
Test evaluation --> neg_mean_absolute_error: -0.3176
Time elapsed: 4.778s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3296 ± 0.0024
Time elapsed: 23.359s
-------------------------------------------------
Total time: 28.137s

Final results ==================== >>
Total time: 36.844s
-------------------------------------
ExtraTrees --> neg_mean_absolute_error: -0.3828 ± 0.0053 ~
LightGBM --> neg_mean_absolute_error: -0.3451 ± 0.0056 ~
CatBoost --> neg_mean_absolute_error: -0.3296 ± 0.0024 ~ !

Run: 2 =========================== >>
Models: CatB1
Size of training set: 16512 (100%)
Size of test set: 4128

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2229
Test evaluation --> neg_mean_absolute_error: -0.2986
Time elapsed: 7.226s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3091 ± 0.0026
Time elapsed: 35.044s
-------------------------------------------------
Total time: 42.269s

Final results ==================== >>
Total time: 42.270s
-------------------------------------
CatBoost --> neg_mean_absolute_error: -0.3091 ± 0.0026 ~
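The schedule visible in the logs (6 models on 17% of the data, then 3 on 33%, then 1 on 100%) follows directly from the successive-halving idea: every run, the surviving models train on a larger slice and only the best half advances. A toy sketch of that schedule (not ATOM's internal code; the "keep best half" step is a stand-in, since no scoring happens here):

```python
# Toy sketch of a successive-halving schedule: each run uses a data
# fraction of 1/n_models, then only the best half of the models survives.
models = ["Tree", "Bag", "ET", "RF", "LGB", "CatB"]

schedule = []
run = 0
while models:
    frac = 1 / len(models)  # fraction of the training set for this run
    schedule.append((run, len(models), frac))
    # ...in the real pipeline, every model is fitted and scored here...
    models = sorted(models)[: len(models) // 2]  # stand-in for "keep best half"
    run += 1

for run, n, frac in schedule:
    print(f"Run {run}: {n} models on {frac:.0%} of the training set")
```

This reproduces the 17% / 33% / 100% fractions seen in the output above.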
Analyze the results
In [5]:
# The results attribute is now multi-indexed, where frac is the
# fraction of the training set used to fit the model. The model
# names end with the number of models fitted during that run
atom.results
Out[5]:
| frac | model | score_train | score_test | time_fit | score_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| 0.17 | Bag6 | -0.1648 | -0.4200 | 0.146132 | -0.428947 | 0.830754 | 0.976886 |
| 0.17 | CatB6 | -0.1591 | -0.3433 | 4.810551 | -0.348576 | 21.199617 | 26.010168 |
| 0.17 | ET6 | -0.0000 | -0.3958 | 0.786993 | -0.401530 | 3.544431 | 4.331424 |
| 0.17 | LGB6 | -0.2022 | -0.3599 | 0.361332 | -0.360678 | 0.847782 | 1.209114 |
| 0.17 | RF6 | -0.1450 | -0.4003 | 1.320198 | -0.405193 | 6.386311 | 7.706509 |
| 0.17 | Tree6 | -0.0000 | -0.5538 | 0.036031 | -0.554181 | 0.321293 | 0.357324 |
| 0.33 | CatB3 | -0.1876 | -0.3176 | 4.777847 | -0.329575 | 23.359312 | 28.137159 |
| 0.33 | ET3 | -0.0000 | -0.3691 | 1.399291 | -0.382764 | 5.967419 | 7.366710 |
| 0.33 | LGB3 | -0.2367 | -0.3342 | 0.366333 | -0.345083 | 0.969394 | 1.335727 |
| 1.00 | CatB1 | -0.2229 | -0.2986 | 7.225810 | -0.309112 | 35.043598 | 42.269408 |
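Because the results frame has a (frac, model) MultiIndex, standard pandas selection applies. A minimal sketch with a toy stand-in frame (only a few of the rows above, so it runs without ATOM installed):

```python
import pandas as pd

# Toy stand-in for atom.results: a (frac, model) MultiIndex as in the table
results = pd.DataFrame(
    {"score_test": [-0.4200, -0.3433, -0.3176, -0.2986]},
    index=pd.MultiIndex.from_tuples(
        [(0.17, "Bag6"), (0.17, "CatB6"), (0.33, "CatB3"), (1.00, "CatB1")],
        names=["frac", "model"],
    ),
)

# Select every model trained on the full training set
print(results.xs(1.00, level="frac"))

# Or every run of a single estimator, via the model level
print(results[results.index.get_level_values("model").str.startswith("CatB")])
```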
In [6]:
# Plot the successive halving's results
atom.plot_successive_halving()
In [9]:
# Use regex to call all the models with the same estimator...
atom.plot_errors(models=["CatB.*"])
# Use regex to call all the models with the same estimator...
atom.plot_errors(models=["CatB.*"])
In [10]:
# ...or to call the models from the same run
atom.plot_errors(models=".*3")
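The two regex patterns work because of the naming convention noted earlier: each fitted model's name ends with the number of models in its run. A small sketch with the plain `re` module showing what each pattern selects (the name list is taken from the runs above):

```python
import re

# Model names from the three runs above; the suffix is the run's model count
names = ["Tree6", "Bag6", "ET6", "RF6", "LGB6", "CatB6",
         "ET3", "LGB3", "CatB3", "CatB1"]

# "CatB.*" selects every CatBoost run...
catboost = [n for n in names if re.fullmatch(r"CatB.*", n)]
print(catboost)  # ['CatB6', 'CatB3', 'CatB1']

# ...while ".*3" selects every model from the second run (3 models)
run1 = [n for n in names if re.fullmatch(r".*3", n)]
print(run1)  # ['ET3', 'LGB3', 'CatB3']
```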