Example: Successive halving¶
This example shows how to compare multiple tree-based models using successive halving.
We use the California housing dataset from sklearn.datasets: a small, easy-to-fit dataset in which the goal is to predict median house prices.
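As a quick illustration of the idea behind successive halving: the models start on a small fraction of the training set, and after every run only the better-scoring half survives and is refitted on more data. The standalone sketch below mimics that schedule with plain scikit-learn on synthetic data; the model pool and sizes are made up for illustration, and this is not ATOM's implementation.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso, LinearRegression, Ridge
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Illustrative model pool and data (not the models used in this example)
X, y = make_regression(n_samples=800, n_features=5, noise=0.3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

models = {
    "Linear": LinearRegression(),
    "Ridge": Ridge(),
    "Lasso": Lasso(),
    "Tree": DecisionTreeRegressor(random_state=1),
}

while len(models) > 1:
    # Each run trains on 1/n_models of the data: 4 models -> 25%, 2 -> 50%
    n = len(X_train) // len(models)
    scores = {
        name: mean_absolute_error(
            y_test, est.fit(X_train[:n], y_train[:n]).predict(X_test)
        )
        for name, est in models.items()
    }
    # Keep the better-scoring half (lower MAE is better)
    survivors = sorted(scores, key=scores.get)[: len(models) // 2]
    models = {name: models[name] for name in survivors}

# The final run fits the single survivor on 100% of the training set
winner, estimator = next(iter(models.items()))
estimator.fit(X_train, y_train)
```

Because early runs use only a slice of the data, weak candidates are discarded cheaply and most of the compute budget goes to the strongest model.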
Load the data¶
In [1]:
from sklearn.datasets import fetch_california_housing
from atom import ATOMRegressor
UserWarning: The pandas version installed (1.5.3) does not match the supported pandas version in Modin (1.5.2). This may cause undesired side effects!
In [2]:
# Load the data
X, y = fetch_california_housing(return_X_y=True)
Run the pipeline¶
In [3]:
atom = ATOMRegressor(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>
Algorithm task: regression.

Dataset stats ==================== >>
Shape: (20640, 9)
Train set size: 16512
Test set size: 4128
-------------------------------------
Memory: 1.49 MB
Scaled: False
Outlier values: 786 (0.5%)
In [4]:
# Compare tree-based models via successive halving
atom.successive_halving(
models=["Tree", "Bag", "ET", "RF", "LGB", "CatB"],
metric="mae",
n_bootstrap=5,
)
Training ========================= >>
Metric: neg_mean_absolute_error


Run: 0 =========================== >>
Models: Tree6, Bag6, ET6, RF6, LGB6, CatB6
Size of training set: 16512 (17%)
Size of test set: 4128


Results for DecisionTree:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.5538
Time elapsed: 0.038s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.5542 ± 0.0284
Time elapsed: 0.279s
-------------------------------------------------
Total time: 0.317s


Results for Bagging:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1648
Test evaluation --> neg_mean_absolute_error: -0.42
Time elapsed: 0.213s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4289 ± 0.0055
Time elapsed: 1.064s
-------------------------------------------------
Total time: 1.278s


Results for ExtraTrees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3958
Time elapsed: 0.740s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4015 ± 0.0029
Time elapsed: 3.456s
-------------------------------------------------
Total time: 4.196s


Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.145
Test evaluation --> neg_mean_absolute_error: -0.4003
Time elapsed: 2.034s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4052 ± 0.005
Time elapsed: 9.483s
-------------------------------------------------
Total time: 11.517s


Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2022
Test evaluation --> neg_mean_absolute_error: -0.3599
Time elapsed: 0.309s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3607 ± 0.0018
Time elapsed: 0.744s
-------------------------------------------------
Total time: 1.053s


Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1591
Test evaluation --> neg_mean_absolute_error: -0.3433
Time elapsed: 4.476s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3486 ± 0.0017
Time elapsed: 20.322s
-------------------------------------------------
Total time: 24.798s


Final results ==================== >>
Total time: 43.169s
-------------------------------------
DecisionTree --> neg_mean_absolute_error: -0.5542 ± 0.0284 ~
Bagging      --> neg_mean_absolute_error: -0.4289 ± 0.0055 ~
ExtraTrees   --> neg_mean_absolute_error: -0.4015 ± 0.0029 ~
RandomForest --> neg_mean_absolute_error: -0.4052 ± 0.005 ~
LightGBM     --> neg_mean_absolute_error: -0.3607 ± 0.0018 ~
CatBoost     --> neg_mean_absolute_error: -0.3486 ± 0.0017 ~ !


Run: 1 =========================== >>
Models: ET3, LGB3, CatB3
Size of training set: 16512 (33%)
Size of test set: 4128


Results for ExtraTrees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3691
Time elapsed: 1.422s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3828 ± 0.0053
Time elapsed: 6.067s
-------------------------------------------------
Total time: 7.489s


Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2367
Test evaluation --> neg_mean_absolute_error: -0.3342
Time elapsed: 0.344s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3451 ± 0.0056
Time elapsed: 0.916s
-------------------------------------------------
Total time: 1.260s


Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1876
Test evaluation --> neg_mean_absolute_error: -0.3176
Time elapsed: 4.643s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3296 ± 0.0024
Time elapsed: 21.537s
-------------------------------------------------
Total time: 26.180s


Final results ==================== >>
Total time: 34.934s
-------------------------------------
ExtraTrees --> neg_mean_absolute_error: -0.3828 ± 0.0053 ~
LightGBM   --> neg_mean_absolute_error: -0.3451 ± 0.0056 ~
CatBoost   --> neg_mean_absolute_error: -0.3296 ± 0.0024 ~ !


Run: 2 =========================== >>
Models: CatB1
Size of training set: 16512 (100%)
Size of test set: 4128


Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2229
Test evaluation --> neg_mean_absolute_error: -0.2986
Time elapsed: 6.611s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3091 ± 0.0026
Time elapsed: 32.724s
-------------------------------------------------
Total time: 39.335s


Final results ==================== >>
Total time: 39.336s
-------------------------------------
CatBoost --> neg_mean_absolute_error: -0.3091 ± 0.0026 ~
Analyze the results¶
In [5]:
# The results attribute is now multi-indexed, where frac is the
# fraction of the training set used to fit the models. Model
# names end with the number of models fitted during that run.
atom.results
Out[5]:
| frac | model | score_train | score_test | time_fit | score_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| 0.17 | Bag6 | -0.1648 | -0.4200 | 0.213194 | -0.428947 | 1.064390 | 1.277584 |
| | CatB6 | -0.1591 | -0.3433 | 4.475700 | -0.348576 | 20.322169 | 24.797869 |
| | ET6 | -0.0000 | -0.3958 | 0.740051 | -0.401530 | 3.456043 | 4.196094 |
| | LGB6 | -0.2022 | -0.3599 | 0.308791 | -0.360678 | 0.744363 | 1.053154 |
| | RF6 | -0.1450 | -0.4003 | 2.033849 | -0.405193 | 9.483478 | 11.517327 |
| | Tree6 | -0.0000 | -0.5538 | 0.038035 | -0.554181 | 0.278868 | 0.316903 |
| 0.33 | CatB3 | -0.1876 | -0.3176 | 4.643221 | -0.329575 | 21.537070 | 26.180291 |
| | ET3 | -0.0000 | -0.3691 | 1.422265 | -0.382764 | 6.067180 | 7.489445 |
| | LGB3 | -0.2367 | -0.3342 | 0.344359 | -0.345083 | 0.915865 | 1.260224 |
| 1.00 | CatB1 | -0.2229 | -0.2986 | 6.611475 | -0.309112 | 32.723607 | 39.335082 |
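Since the rows above sit on a pandas (frac, model) MultiIndex, individual runs and estimators can be sliced directly. The miniature below reproduces a few rows of the table by hand to show the idioms; only the `score_test` column is copied, and the slicing calls are standard pandas, not ATOM-specific helpers.

```python
import pandas as pd

# Miniature of atom.results with the same (frac, model) MultiIndex;
# only a handful of rows from the table above are reproduced here.
results = pd.DataFrame(
    {"score_test": [-0.5538, -0.3433, -0.3176, -0.2986]},
    index=pd.MultiIndex.from_tuples(
        [(0.17, "Tree6"), (0.17, "CatB6"), (0.33, "CatB3"), (1.00, "CatB1")],
        names=["frac", "model"],
    ),
)

# All models from the first run (17% of the training set)
first_run = results.loc[0.17]

# One estimator followed across runs, filtered on the model level
catb = results[results.index.get_level_values("model").str.startswith("CatB")]
```

The same `.loc[frac]` and level-based filtering work unchanged on the full `atom.results` frame.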
In [6]:
# Plot the successive halving's results
atom.plot_successive_halving()
In [7]:
# Use regex to call all the models with the same estimator...
atom.plot_errors(models=["CatB.*"])
In [8]:
# ...or to call the models from the same run
atom.plot_errors(models=".*3")
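The two patterns are plain regular expressions matched against the model names, whose suffix is the number of models in the run they come from. A small standalone sketch of the idea (using `re.fullmatch` for illustration; the name list is copied from the runs above):

```python
import re

# Model names from the three runs above; the suffix is the number of
# models fitted in that run (6, 3, or 1).
names = ["Tree6", "Bag6", "ET6", "RF6", "LGB6", "CatB6",
         "ET3", "LGB3", "CatB3", "CatB1"]

# "CatB.*" selects every CatBoost fit, regardless of run
same_estimator = [n for n in names if re.fullmatch("CatB.*", n)]

# ".*3" selects the three models from the second run
same_run = [n for n in names if re.fullmatch(".*3", n)]
```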