Successive halving
This example shows how to compare multiple tree-based models using successive halving.
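Successive halving trains every candidate on a small slice of the training set, keeps the best-scoring half, and repeats with a larger slice until one model remains. The sketch below is a minimal, library-free illustration of that loop. It assumes the rule suggested by this example's log: each run trains on 1/n of the training set (n being the number of models left) and keeps the best half, rounded down. The successive_halving function and the LOGGED_SCORES dictionary are hypothetical. The scores are the bootstrap means copied from this example's training log and are used only to replay the selection, not to reproduce ATOM's implementation.

```python
# Negated MAE (higher is better): bootstrap means copied from this example's log,
# keyed by the fraction of the training set used in each run
LOGGED_SCORES = {
    0.17: {"Tree": -4.3307, "Bag": -3.0957, "ET": -2.5554,
           "RF": -2.9574, "LGB": -4.8485, "CatB": -2.9193},
    0.33: {"ET": -2.6016, "RF": -2.7619, "CatB": -2.5595},
}


def successive_halving(models, score):
    """Return [(fraction, models_in_run), ...] for every run.

    score(model, frac) -> float is a hypothetical scoring callable; higher is better.
    """
    runs = []
    survivors = list(models)
    while True:
        frac = round(1 / len(survivors), 2)  # train on 1/n of the training set
        runs.append((frac, list(survivors)))
        if len(survivors) == 1:
            return runs
        scores = {m: score(m, frac) for m in survivors}
        n_keep = max(1, len(survivors) // 2)  # keep the best half, rounded down
        survivors = sorted(survivors, key=scores.get, reverse=True)[:n_keep]


runs = successive_halving(list(LOGGED_SCORES[0.17]), lambda m, f: LOGGED_SCORES[f][m])
for frac, models in runs:
    print(frac, models)
# 0.17 ['Tree', 'Bag', 'ET', 'RF', 'LGB', 'CatB']
# 0.33 ['ET', 'CatB', 'RF']
# 1.0 ['CatB']
```

Replaying the logged scores gives the same 6 → 3 → 1 schedule and the same final CatBoost pick as the log. In run 1, CatBoost survives on its bootstrap mean (-2.5595) even though Extra-Trees has the better single test score (-2.2361 vs -2.42).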
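Import the Boston housing dataset from sklearn.datasets. This is a small, easy-to-train dataset whose goal is to predict house prices. Note that load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so running this example as-is requires a scikit-learn version older than 1.2.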
Load the data
In [1]:
from sklearn.datasets import load_boston
from atom import ATOMRegressor
In [2]:
# Load the data
X, y = load_boston(return_X_y=True)
Run the pipeline
In [3]:
atom = ATOMRegressor(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>
Algorithm task: regression.

Dataset stats ====================== >>
Shape: (506, 14)
Scaled: False
Outlier values: 82 (1.4%)
---------------------------------------
Train set size: 405
Test set size: 101
In [4]:
# Compare tree-based models via successive halving
atom.successive_halving(
models=["Tree", "Bag", "ET", "RF", "LGB", "CatB"],
metric="mae",
n_bootstrap=5,
)
Training ===================================== >>
Metric: neg_mean_absolute_error

Run: 0 ================================ >>
Models: Tree6, Bag6, ET6, RF6, LGB6, CatB6
Size of training set: 405 (17%)
Size of test set: 101

Results for Decision Tree:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -3.3257
Time elapsed: 0.006s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -4.3307 ± 0.525
Time elapsed: 0.020s
-------------------------------------------------
Total time: 0.026s

Results for Bagging Regressor:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -1.3054
Test evaluation --> neg_mean_absolute_error: -2.695
Time elapsed: 0.021s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -3.0957 ± 0.2677
Time elapsed: 0.091s
-------------------------------------------------
Total time: 0.112s

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -2.1541
Time elapsed: 0.084s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.5554 ± 0.1708
Time elapsed: 0.359s
-------------------------------------------------
Total time: 0.444s

Results for Random Forest:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -1.1509
Test evaluation --> neg_mean_absolute_error: -2.4143
Time elapsed: 0.110s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.9574 ± 0.2253
Time elapsed: 0.494s
-------------------------------------------------
Total time: 0.605s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -3.3965
Test evaluation --> neg_mean_absolute_error: -4.4873
Time elapsed: 0.026s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -4.8485 ± 0.2679
Time elapsed: 0.070s
-------------------------------------------------
Total time: 0.097s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0806
Test evaluation --> neg_mean_absolute_error: -2.3991
Time elapsed: 1.256s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.9193 ± 0.2604
Time elapsed: 4.204s
-------------------------------------------------
Total time: 5.460s

Final results ========================= >>
Duration: 6.745s
------------------------------------------
Decision Tree --> neg_mean_absolute_error: -4.3307 ± 0.525 ~
Bagging Regressor --> neg_mean_absolute_error: -3.0957 ± 0.2677 ~
Extra-Trees --> neg_mean_absolute_error: -2.5554 ± 0.1708 ~ !
Random Forest --> neg_mean_absolute_error: -2.9574 ± 0.2253 ~
LightGBM --> neg_mean_absolute_error: -4.8485 ± 0.2679 ~
CatBoost --> neg_mean_absolute_error: -2.9193 ± 0.2604 ~

Run: 1 ================================ >>
Models: ET3, RF3, CatB3
Size of training set: 405 (33%)
Size of test set: 101

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -2.2361
Time elapsed: 0.097s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.6016 ± 0.289
Time elapsed: 0.397s
-------------------------------------------------
Total time: 0.494s

Results for Random Forest:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.982
Test evaluation --> neg_mean_absolute_error: -2.5055
Time elapsed: 0.127s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.7619 ± 0.1947
Time elapsed: 0.563s
-------------------------------------------------
Total time: 0.690s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2835
Test evaluation --> neg_mean_absolute_error: -2.42
Time elapsed: 1.797s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.5595 ± 0.2768
Time elapsed: 6.725s
-------------------------------------------------
Total time: 8.523s

Final results ========================= >>
Duration: 9.708s
------------------------------------------
Extra-Trees --> neg_mean_absolute_error: -2.6016 ± 0.289 ~
Random Forest --> neg_mean_absolute_error: -2.7619 ± 0.1947 ~
CatBoost --> neg_mean_absolute_error: -2.5595 ± 0.2768 ~ !

Run: 2 ================================ >>
Models: CatB1
Size of training set: 405 (100%)
Size of test set: 101

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.3978
Test evaluation --> neg_mean_absolute_error: -1.8776
Time elapsed: 3.155s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -2.0515 ± 0.0902
Time elapsed: 15.202s
-------------------------------------------------
Total time: 18.358s

Final results ========================= >>
Duration: 18.359s
------------------------------------------
CatBoost --> neg_mean_absolute_error: -2.0515 ± 0.0902 ~
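In the log, the metric passed as "mae" is reported as neg_mean_absolute_error, so higher (closer to zero) is better. Each "Bootstrap" line gives a mean ± standard deviation over the n_bootstrap=5 evaluations. The following is a minimal sketch of such a bootstrap estimate. It assumes the model is refit on resamples of the training set drawn with replacement and scored on the fixed test set. fit_predict and mean_predictor are hypothetical stand-ins for a real estimator, and ATOM's exact procedure (for example, population versus sample standard deviation) may differ.

```python
import numpy as np


def bootstrap_mae(X_train, y_train, X_test, y_test, fit_predict, n_bootstrap=5, seed=1):
    """Refit on bootstrap resamples of the training set and score each fit on the test set.

    fit_predict is a hypothetical callable (X_train, y_train, X_test) -> predictions.
    Returns (mean, std) of the negated MAE, so higher is better.
    """
    rng = np.random.default_rng(seed)
    scores = []
    for _ in range(n_bootstrap):
        # Draw len(y_train) row indices with replacement
        idx = rng.integers(0, len(y_train), size=len(y_train))
        y_pred = fit_predict(X_train[idx], y_train[idx], X_test)
        scores.append(-np.mean(np.abs(y_test - y_pred)))  # negated MAE
    return float(np.mean(scores)), float(np.std(scores))


def mean_predictor(X_tr, y_tr, X_te):
    # Hypothetical baseline "model": always predict the training-set mean
    return np.full(len(X_te), y_tr.mean())


# Toy regression data, not the Boston dataset
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=100)

mean, std = bootstrap_mae(X[:80], y[:80], X[80:], y[80:], mean_predictor)
print(f"neg_mean_absolute_error: {mean:.4f} ± {std:.4f}")
```

A small spread, as for CatB1 (± 0.0902), suggests the score is stable across resamples. A large one, as for Tree6 (± 0.525), means a single test score should be read with caution.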
Analyze results
In [5]:
# The results are now multi-indexed: frac is the fraction
# of the training set used to fit the model, and each model
# name ends with the number of models fitted during that run
atom.results
Out[5]:
| frac | model | metric_train | metric_test | time_fit | mean_bootstrap | std_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|---|
| 0.17 | Bag6 | -1.305373e+00 | -2.695050 | 0.021s | -3.095663 | 0.267668 | 0.091s | 0.112s |
| | CatB6 | -8.055503e-02 | -2.399073 | 1.256s | -2.919304 | 0.260378 | 4.204s | 5.460s |
| | ET6 | -2.256238e-14 | -2.154089 | 0.084s | -2.555434 | 0.170823 | 0.359s | 0.444s |
| | LGB6 | -3.396511e+00 | -4.487270 | 0.026s | -4.848536 | 0.267874 | 0.070s | 0.097s |
| | RF6 | -1.150866e+00 | -2.414297 | 0.110s | -2.957400 | 0.225311 | 0.494s | 0.605s |
| | Tree6 | -0.000000e+00 | -3.325743 | 0.006s | -4.330693 | 0.525026 | 0.020s | 0.026s |
| 0.33 | CatB3 | -2.835499e-01 | -2.420032 | 1.797s | -2.559497 | 0.276791 | 6.725s | 8.523s |
| | ET3 | -2.315185e-14 | -2.236079 | 0.097s | -2.601648 | 0.289034 | 0.397s | 0.494s |
| | RF3 | -9.819778e-01 | -2.505465 | 0.127s | -2.761887 | 0.194678 | 0.563s | 0.690s |
| 1.00 | CatB1 | -3.977985e-01 | -1.877590 | 3.155s | -2.051462 | 0.090227 | 15.202s | 18.358s |
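The table above is indexed by (frac, model). Assuming atom.results is a regular pandas DataFrame with that MultiIndex, as the output suggests, standard pandas indexing can slice it per run. The sketch below builds a small hand-made stand-in (the results variable, with a subset of rows and columns copied from the table) rather than calling ATOM. It then selects the second run and finds the best model of each run.

```python
import pandas as pd

# Hand-built stand-in for atom.results, using values copied from the table above
index = pd.MultiIndex.from_tuples(
    [
        (0.17, "CatB6"), (0.17, "ET6"), (0.17, "RF6"),
        (0.33, "CatB3"), (0.33, "ET3"), (0.33, "RF3"),
        (1.00, "CatB1"),
    ],
    names=["frac", "model"],
)
results = pd.DataFrame(
    {
        "metric_test": [-2.399073, -2.154089, -2.414297,
                        -2.420032, -2.236079, -2.505465, -1.877590],
        "mean_bootstrap": [-2.919304, -2.555434, -2.957400,
                           -2.559497, -2.601648, -2.761887, -2.051462],
    },
    index=index,
)

# All models fitted on 33% of the training set (the second run)
run_1 = results.xs(0.33, level="frac")
print(run_1)

# Best model per run: the metric is negated MAE, so the maximum is the best;
# idxmax returns the full (frac, model) label of the winning row
best_per_run = results["mean_bootstrap"].groupby(level="frac").idxmax()
print(best_per_run)
```

Ranking by metric_test instead would favour ET3 in the second run, which is why the bootstrap column matters when comparing models.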
In [6]:
# Plot the successive halving's results
atom.plot_successive_halving()
In [7]:
# Use an acronym to select all models that use the same estimator
atom.plot_errors(models=["CatB"])
In [8]:
# Use a number to select all models fitted in the same run
atom.plot_errors(models="3")
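The two selection styles used in the plotting calls can be illustrated with a small name-matching sketch. One style is an estimator acronym such as "CatB", and the other is a run suffix such as "3". The select_models helper and the MODEL_NAMES list are hypothetical and only mimic the behaviour the comments describe. This is not ATOM's implementation.

```python
import re

# Model names as they appear in this example's results
MODEL_NAMES = ["Tree6", "Bag6", "ET6", "RF6", "LGB6", "CatB6", "ET3", "RF3", "CatB3", "CatB1"]


def select_models(names, key):
    """Hypothetical helper: match a full name, an acronym, or a run suffix."""
    selected = []
    for name in names:
        # Split e.g. "CatB6" into the acronym "CatB" and the suffix "6"
        acronym, suffix = re.fullmatch(r"([A-Za-z]+)(\d+)", name).groups()
        if key in (name, acronym, suffix):
            selected.append(name)
    return selected


print(select_models(MODEL_NAMES, "CatB"))  # ['CatB6', 'CatB3', 'CatB1']
print(select_models(MODEL_NAMES, "3"))     # ['ET3', 'RF3', 'CatB3']
```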