Successive halving
This example shows how to compare multiple tree-based models using successive halving.
Import the California housing dataset from sklearn.datasets. This is a small, easy-to-train dataset whose goal is to predict the median house value of California districts.
Load the data
In [1]:
from sklearn.datasets import fetch_california_housing
from atom import ATOMRegressor
In [2]:
# Load the data
X, y = fetch_california_housing(return_X_y=True)
Run the pipeline
In [3]:
atom = ATOMRegressor(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>
Algorithm task: regression.

Dataset stats ==================== >>
Shape: (20640, 9)
Memory: 1.49 MB
Scaled: False
Outlier values: 799 (0.5%)
-------------------------------------
Train set size: 16512
Test set size: 4128
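The set sizes above are consistent with holding out 20% of the rows for testing. A quick arithmetic check, assuming ATOM's default `test_size=0.2` (this example does not set it explicitly); the reported shape also counts the target column, which leaves eight features.

```python
# Sanity-check the train/test sizes reported above.
# Assumption: ATOM holds out 20% of the rows by default (test_size=0.2).
n_rows, n_columns = 20640, 9  # from "Shape: (20640, 9)"
test_size = 0.2

n_test = round(n_rows * test_size)
n_train = n_rows - n_test
n_features = n_columns - 1  # one column is the target

print(n_train, n_test, n_features)  # 16512 4128 8
```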
In [4]:
# Compare tree-based models via successive halving
atom.successive_halving(
models=["Tree", "Bag", "ET", "RF", "LGB", "CatB"],
metric="mae",
n_bootstrap=5,
)
Run: 0 ================================ >>
Models: Tree6, Bag6, ET6, RF6, LGB6, CatB6
Size of training set: 16512 (17%)
Size of test set: 4128

Training ========================= >>
Metric: neg_mean_absolute_error

Results for Decision Tree:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.5598
Time elapsed: 0.029s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.546 ± 0.0081
Time elapsed: 0.276s
-------------------------------------------------
Total time: 0.305s

Results for Bagging:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1649
Test evaluation --> neg_mean_absolute_error: -0.4074
Time elapsed: 0.130s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4145 ± 0.0081
Time elapsed: 0.740s
-------------------------------------------------
Total time: 0.870s

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3749
Time elapsed: 0.635s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3834 ± 0.0008
Time elapsed: 2.894s
-------------------------------------------------
Total time: 3.531s

Results for Random Forest:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1457
Test evaluation --> neg_mean_absolute_error: -0.3905
Time elapsed: 1.185s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3969 ± 0.0046
Time elapsed: 5.579s
-------------------------------------------------
Total time: 6.766s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1945
Test evaluation --> neg_mean_absolute_error: -0.3457
Time elapsed: 0.176s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3549 ± 0.0024
Time elapsed: 0.729s
-------------------------------------------------
Total time: 0.905s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1637
Test evaluation --> neg_mean_absolute_error: -0.3319
Time elapsed: 3.576s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.341 ± 0.0008
Time elapsed: 17.036s
-------------------------------------------------
Total time: 20.612s

Final results ==================== >>
Duration: 32.989s
-------------------------------------
Decision Tree --> neg_mean_absolute_error: -0.546 ± 0.0081 ~
Bagging --> neg_mean_absolute_error: -0.4145 ± 0.0081 ~
Extra-Trees --> neg_mean_absolute_error: -0.3834 ± 0.0008 ~
Random Forest --> neg_mean_absolute_error: -0.3969 ± 0.0046 ~
LightGBM --> neg_mean_absolute_error: -0.3549 ± 0.0024 ~
CatBoost --> neg_mean_absolute_error: -0.341 ± 0.0008 ~ !
Run: 1 ================================ >>
Models: ET3, LGB3, CatB3
Size of training set: 16512 (33%)
Size of test set: 4128

Training ========================= >>
Metric: neg_mean_absolute_error

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3575
Time elapsed: 1.231s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3685 ± 0.0033
Time elapsed: 5.138s
-------------------------------------------------
Total time: 6.372s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.231
Test evaluation --> neg_mean_absolute_error: -0.3278
Time elapsed: 0.224s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3359 ± 0.0032
Time elapsed: 1.020s
-------------------------------------------------
Total time: 1.245s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1954
Test evaluation --> neg_mean_absolute_error: -0.3138
Time elapsed: 3.758s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3221 ± 0.0019
Time elapsed: 18.874s
-------------------------------------------------
Total time: 22.632s

Final results ==================== >>
Duration: 30.250s
-------------------------------------
Extra-Trees --> neg_mean_absolute_error: -0.3685 ± 0.0033 ~
LightGBM --> neg_mean_absolute_error: -0.3359 ± 0.0032 ~
CatBoost --> neg_mean_absolute_error: -0.3221 ± 0.0019 ~ !
Run: 2 ================================ >>
Models: CatB1
Size of training set: 16512 (100%)
Size of test set: 4128

Training ========================= >>
Metric: neg_mean_absolute_error

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.232
Test evaluation --> neg_mean_absolute_error: -0.2932
Time elapsed: 5.425s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3028 ± 0.0016
Time elapsed: 26.854s
-------------------------------------------------
Total time: 32.279s

Final results ==================== >>
Duration: 32.279s
-------------------------------------
CatBoost --> neg_mean_absolute_error: -0.3028 ± 0.0016 ~
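The three runs follow a halving schedule: six models on 17% of the training set, the best three on 33%, and the single best on 100%. The plain-Python sketch below reproduces that schedule; it is not ATOM's implementation. It assumes, from the percentages in the log, that a run with n remaining models trains on 1/n of the training set and that the better half (rounded down) advances, ranked by score. `fit_and_score` is a hypothetical stand-in for training and scoring a model; here it returns the fixed run 0 bootstrap means from the log.

```python
def successive_halving(models, fit_and_score, n_train):
    """Keep the best half of the models each run while growing the data.

    Sketch only: assumes a run with n models uses 1/n of the training set.
    Returns a list of (fraction, n_rows, scores) tuples, one per run.
    """
    history = []
    while True:
        n_models = len(models)
        frac = 1 / n_models  # 6 models -> 17%, 3 -> 33%, 1 -> 100%
        n_rows = round(n_train * frac)  # rows used in this run
        scores = {m: fit_and_score(m, n_rows) for m in models}
        history.append((frac, n_rows, scores))
        if n_models == 1:
            return history
        # Scores are negated MAE (higher is better); keep the top half
        ranked = sorted(models, key=scores.get, reverse=True)
        models = ranked[: n_models // 2]


# Hypothetical fit_and_score: look up the run 0 bootstrap means from the log
run0 = {"Tree": -0.546, "Bag": -0.4145, "ET": -0.3834,
        "RF": -0.3969, "LGB": -0.3549, "CatB": -0.341}

history = successive_halving(list(run0), lambda m, n: run0[m], n_train=16512)
for frac, n_rows, scores in history:
    print(f"{frac:.0%} ({n_rows} rows): {sorted(scores)}")
```

With these scores the sketch keeps CatB, LGB and ET after the first run and CatB after the second, which matches the models trained in runs 1 and 2.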
Analyze results
In [5]:
# The results now have a multi-index, where frac is the fraction
# of the training set used to fit the model. Each model name
# ends with the number of models fitted during that run
atom.results
Out[5]:
| frac | model | metric_train | metric_test | time_fit | mean_bootstrap | std_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|---|
| 0.17 | Bag6 | -1.649188e-01 | -0.407445 | 0.130s | -0.414546 | 0.008140 | 0.740s | 0.870s |
| 0.17 | CatB6 | -2.320437e-01 | -0.293153 | 3.576s | -0.340974 | 0.000769 | 17.036s | 20.612s |
| 0.17 | ET6 | -2.544643e-15 | -0.357537 | 0.635s | -0.383449 | 0.000824 | 2.894s | 3.531s |
| 0.17 | LGB6 | -2.310124e-01 | -0.327832 | 0.176s | -0.354878 | 0.002439 | 0.729s | 0.905s |
| 0.17 | RF6 | -1.457109e-01 | -0.390478 | 1.185s | -0.396854 | 0.004553 | 5.579s | 6.766s |
| 0.17 | Tree6 | -3.324214e-17 | -0.559810 | 0.029s | -0.545994 | 0.008073 | 0.276s | 0.305s |
| 0.33 | CatB3 | -2.320437e-01 | -0.293153 | 3.758s | -0.322100 | 0.001936 | 18.874s | 22.632s |
| 0.33 | ET3 | -2.544643e-15 | -0.357537 | 1.231s | -0.368461 | 0.003350 | 5.138s | 6.372s |
| 0.33 | LGB3 | -2.310124e-01 | -0.327832 | 0.224s | -0.335865 | 0.003234 | 1.020s | 1.245s |
| 1.00 | CatB1 | -2.320437e-01 | -0.293153 | 5.425s | -0.302825 | 0.001575 | 26.854s | 32.279s |
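Two details help when reading this table. The scores are negated MAE (scikit-learn's `neg_mean_absolute_error` convention, where greater is better), so values closer to zero are better. And `mean_bootstrap`/`std_bootstrap` summarize the `n_bootstrap=5` bootstrap rounds. The plain-Python sketch below illustrates both ideas under stated assumptions: each round refits on a training set resampled with replacement and scores on the fixed test set (the bootstrap times, several times `time_fit`, are consistent with that, but this is not ATOM's code), the toy "model" just predicts the training mean, the data are made up, and using the population standard deviation is an arbitrary choice.

```python
import random
import statistics


def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between targets and predictions."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def neg_mae(y_true, y_pred):
    """Scorer-style MAE: negated so that a higher score is better."""
    return -mean_absolute_error(y_true, y_pred)


def bootstrap_scores(y_train, y_test, n_bootstrap=5, seed=1):
    """Refit a toy mean predictor on n_bootstrap resampled training sets.

    Returns the mean and population standard deviation of the scores.
    """
    rng = random.Random(seed)
    scores = []
    for _ in range(n_bootstrap):
        # Resample the training targets with replacement, same size
        sample = rng.choices(y_train, k=len(y_train))
        prediction = statistics.fmean(sample)  # "fit": predict the mean
        scores.append(neg_mae(y_test, [prediction] * len(y_test)))
    return statistics.fmean(scores), statistics.pstdev(scores)


# Hypothetical targets, not the California housing data
y_train = [1.2, 2.5, 0.8, 3.1, 1.9, 2.2, 4.0, 1.5]
y_test = [2.0, 1.1, 3.3, 2.7]
mean, std = bootstrap_scores(y_train, y_test)
print(f"neg_mean_absolute_error: {mean:.4f} ± {std:.4f}")
```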
In [6]:
# Plot the results of the successive halving
atom.plot_successive_halving()
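The best model at each training fraction can also be read directly from the `mean_bootstrap` column of the results table. Because the scores are negated MAE, the best model is the one with the highest (least negative) value. A dependency-free sketch, with the values copied from the table and keyed by `(frac, model)` like the multi-index:

```python
from collections import defaultdict

# (frac, model) -> mean_bootstrap, copied from the results table
mean_bootstrap = {
    (0.17, "Bag6"): -0.414546,
    (0.17, "CatB6"): -0.340974,
    (0.17, "ET6"): -0.383449,
    (0.17, "LGB6"): -0.354878,
    (0.17, "RF6"): -0.396854,
    (0.17, "Tree6"): -0.545994,
    (0.33, "CatB3"): -0.322100,
    (0.33, "ET3"): -0.368461,
    (0.33, "LGB3"): -0.335865,
    (1.00, "CatB1"): -0.302825,
}

# Group the models by the fraction of the training set they were fitted on
by_frac = defaultdict(dict)
for (frac, model), score in mean_bootstrap.items():
    by_frac[frac][model] = score

# Negated MAE: the highest value corresponds to the lowest error
best = {frac: max(models, key=models.get) for frac, models in by_frac.items()}
print(best)  # {0.17: 'CatB6', 0.33: 'CatB3', 1.0: 'CatB1'}
```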
In [7]:
# Use an acronym to select all models with the same estimator
atom.plot_errors(models=["CatB"])
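Each model fitted during successive halving is named after the estimator's acronym plus the number of models in its run, so `"CatB"` covers `CatB6`, `CatB3` and `CatB1`. The hypothetical helper below only illustrates that naming scheme with a regular expression; it is not how ATOM resolves the `models` parameter internally.

```python
import re

# Model names produced by the three runs above
MODEL_NAMES = ["Tree6", "Bag6", "ET6", "RF6", "LGB6", "CatB6",
               "ET3", "LGB3", "CatB3", "CatB1"]


def select_by_acronym(names, acronym):
    """Hypothetical: return the names made of `acronym` followed by digits."""
    pattern = re.compile(rf"{re.escape(acronym)}\d+")
    return [name for name in names if pattern.fullmatch(name)]


print(select_by_acronym(MODEL_NAMES, "CatB"))  # ['CatB6', 'CatB3', 'CatB1']
```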
In [8]:
# Use the number to select all models from the same run
atom.plot_errors(models="3")
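Selecting by number works from the other end of the name: `"3"` matches the suffix shared by every model from the run with three models (`ET3`, `LGB3`, `CatB3`). Another hypothetical helper illustrating the naming scheme, not ATOM's implementation; it compares the whole numeric suffix so that, for example, a name ending in `13` would not match `"3"`.

```python
import re

# Model names produced by the three runs above
MODEL_NAMES = ["Tree6", "Bag6", "ET6", "RF6", "LGB6", "CatB6",
               "ET3", "LGB3", "CatB3", "CatB1"]


def select_by_run(names, n_models):
    """Hypothetical: return the names whose numeric suffix equals n_models."""
    selected = []
    for name in names:
        match = re.search(r"\d+$", name)  # trailing digits, if any
        if match and match.group() == str(n_models):
            selected.append(name)
    return selected


print(select_by_run(MODEL_NAMES, "3"))  # ['ET3', 'LGB3', 'CatB3']
```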