Successive halving
This example shows how to compare multiple tree-based models using successive halving: every model starts out on a small fraction of the training set, and after each run only the best-performing half moves on, until a single model remains that is fitted on the complete set.
Import the California housing dataset from sklearn.datasets. This is a small, easy-to-train dataset whose goal is to predict house prices.
Load the data
In [1]:
from sklearn.datasets import fetch_california_housing
from atom import ATOMRegressor
In [2]:
# Load the data
X, y = fetch_california_housing(return_X_y=True)
Run the pipeline
In [3]:
atom = ATOMRegressor(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>
Algorithm task: regression.

Dataset stats ==================== >>
Shape: (20640, 9)
Scaled: False
Outlier values: 811 (0.5%)
-------------------------------------
Train set size: 16512
Test set size: 4128
-------------------------------------
In [4]:
# Compare tree-based models via successive halving
atom.successive_halving(
models=["Tree", "Bag", "ET", "RF", "LGB", "CatB"],
metric="mae",
n_bootstrap=5,
)
Run: 0 ================================ >>
Models: Tree6, Bag6, ET6, RF6, LGB6, CatB6
Size of training set: 16512 (17%)
Size of test set: 4128

Training ========================= >>
Metric: neg_mean_absolute_error

Results for Decision Tree:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.564
Time elapsed: 0.029s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.5626 ± 0.0194
Time elapsed: 0.280s
-------------------------------------------------
Total time: 0.309s

Results for Bagging Regressor:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1716
Test evaluation --> neg_mean_absolute_error: -0.4253
Time elapsed: 0.129s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4284 ± 0.0042
Time elapsed: 0.733s
-------------------------------------------------
Total time: 0.863s

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3859
Time elapsed: 0.610s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3921 ± 0.0015
Time elapsed: 2.743s
-------------------------------------------------
Total time: 3.356s

Results for Random Forest:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1491
Test evaluation --> neg_mean_absolute_error: -0.3998
Time elapsed: 1.224s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.4074 ± 0.003
Time elapsed: 5.745s
-------------------------------------------------
Total time: 6.971s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2019
Test evaluation --> neg_mean_absolute_error: -0.35
Time elapsed: 0.180s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3602 ± 0.0031
Time elapsed: 0.734s
-------------------------------------------------
Total time: 0.915s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1604
Test evaluation --> neg_mean_absolute_error: -0.3335
Time elapsed: 3.526s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3439 ± 0.0016
Time elapsed: 17.254s
-------------------------------------------------
Total time: 20.780s

Final results ==================== >>
Duration: 33.194s
-------------------------------------
Decision Tree     --> neg_mean_absolute_error: -0.5626 ± 0.0194 ~
Bagging Regressor --> neg_mean_absolute_error: -0.4284 ± 0.0042 ~
Extra-Trees       --> neg_mean_absolute_error: -0.3921 ± 0.0015 ~
Random Forest     --> neg_mean_absolute_error: -0.4074 ± 0.003 ~
LightGBM          --> neg_mean_absolute_error: -0.3602 ± 0.0031 ~
CatBoost          --> neg_mean_absolute_error: -0.3439 ± 0.0016 ~ !
Run: 1 ================================ >>
Models: ET3, LGB3, CatB3
Size of training set: 16512 (33%)
Size of test set: 4128

Training ========================= >>
Metric: neg_mean_absolute_error

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.0
Test evaluation --> neg_mean_absolute_error: -0.3527
Time elapsed: 1.154s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3674 ± 0.0014
Time elapsed: 5.003s
-------------------------------------------------
Total time: 6.162s

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.235
Test evaluation --> neg_mean_absolute_error: -0.326
Time elapsed: 0.197s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3353 ± 0.0018
Time elapsed: 0.889s
-------------------------------------------------
Total time: 1.086s

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.1945
Test evaluation --> neg_mean_absolute_error: -0.314
Time elapsed: 3.866s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.3232 ± 0.0003
Time elapsed: 19.343s
-------------------------------------------------
Total time: 23.209s

Final results ==================== >>
Duration: 30.459s
-------------------------------------
Extra-Trees --> neg_mean_absolute_error: -0.3674 ± 0.0014 ~
LightGBM    --> neg_mean_absolute_error: -0.3353 ± 0.0018 ~
CatBoost    --> neg_mean_absolute_error: -0.3232 ± 0.0003 ~ !


Run: 2 ================================ >>
Models: CatB1
Size of training set: 16512 (100%)
Size of test set: 4128

Training ========================= >>
Metric: neg_mean_absolute_error

Results for CatBoost:
Fit ---------------------------------------------
Train evaluation --> neg_mean_absolute_error: -0.2317
Test evaluation --> neg_mean_absolute_error: -0.2911
Time elapsed: 5.315s
Bootstrap ---------------------------------------
Evaluation --> neg_mean_absolute_error: -0.2989 ± 0.001
Time elapsed: 26.528s
-------------------------------------------------
Total time: 31.843s

Final results ==================== >>
Duration: 31.844s
-------------------------------------
CatBoost --> neg_mean_absolute_error: -0.2989 ± 0.001 ~
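Note how the training-set fraction follows from the number of models still competing: run 0 fits six models on roughly 1/6 of the training data (17%), run 1 fits the best three on 1/3 (33%), and run 2 fits the single remaining model on the full set. The helper below is purely illustrative (it is not part of ATOM's API); it only reproduces the schedule reported in the logs above.

# Illustrative only: reproduce the data fractions shown in the logs.
# Each run uses 1/n of the training set, where n is the number of
# models still in the race, and roughly the best half advances.
def halving_schedule(n_models):
    schedule = []
    while n_models >= 1:
        schedule.append((n_models, round(1 / n_models, 2)))
        n_models //= 2
    return schedule

print(halving_schedule(6))  # [(6, 0.17), (3, 0.33), (1, 1.0)]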
Analyze results
In [5]:
# The results are now multi-indexed, where frac is the fraction
# of the training set used to fit the model. The model names
# end with the number of models fitted during that run
atom.results
Out[5]:
| frac | model | metric_train | metric_test | time_fit | mean_bootstrap | std_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|---|
| 0.17 | Bag6 | -1.716132e-01 | -0.425329 | 0.129s | -0.428422 | 0.004202 | 0.733s | 0.863s |
| 0.17 | CatB6 | -2.316594e-01 | -0.291079 | 3.526s | -0.343854 | 0.001604 | 17.254s | 20.780s |
| 0.17 | ET6 | -2.500448e-15 | -0.352708 | 0.610s | -0.392057 | 0.001472 | 2.743s | 3.356s |
| 0.17 | LGB6 | -2.350471e-01 | -0.325999 | 0.180s | -0.360198 | 0.003051 | 0.734s | 0.915s |
| 0.17 | RF6 | -1.491169e-01 | -0.399763 | 1.224s | -0.407387 | 0.003003 | 5.745s | 6.971s |
| 0.17 | Tree6 | -3.743775e-17 | -0.563991 | 0.029s | -0.562617 | 0.019377 | 0.280s | 0.309s |
| 0.33 | CatB3 | -2.316594e-01 | -0.291079 | 3.866s | -0.323212 | 0.000260 | 19.343s | 23.209s |
| 0.33 | ET3 | -2.500448e-15 | -0.352708 | 1.154s | -0.367404 | 0.001401 | 5.003s | 6.162s |
| 0.33 | LGB3 | -2.350471e-01 | -0.325999 | 0.197s | -0.335340 | 0.001817 | 0.889s | 1.086s |
| 1.00 | CatB1 | -2.316594e-01 | -0.291079 | 5.315s | -0.298891 | 0.001028 | 26.528s | 31.843s |
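Because atom.results is a regular pandas DataFrame with a (frac, model) multi-index, the usual pandas selection tools apply. A minimal sketch (standard pandas, nothing ATOM-specific):

# Select all rows of a single run via the frac level of the index
atom.results.loc[1.00]

# Follow one estimator across runs via the model level of the index
mask = atom.results.index.get_level_values("model").str.startswith("CatB")
atom.results[mask]["mean_bootstrap"]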
In [6]:
# Plot the successive halving's results
atom.plot_successive_halving()
In [7]:
# Use an acronym to call all the models with the same estimator
atom.plot_errors(models=["CatB"])
In [8]:
# Use the number to call the models from the same run
atom.plot_errors(models="3")
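A single model can also be addressed by its full run-suffixed name, as listed in the results index; a minimal example, assuming those names are accepted by the models parameter (CatB1 is taken from the table above):

# Plot the errors of one specific model, the final CatBoost fitted on all the data
atom.plot_errors(models="CatB1")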