Train sizing
This example shows how to assess a model's performance based on the size of the training set.
The data used is a variation on the Australian weather dataset from Kaggle. You can download it from here. The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target column RainTomorrow.
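Conceptually, train sizing fits the same model on increasingly large subsets of the training set and scores every fit on the same, fixed test set. The pure-Python sketch below illustrates that loop under stated assumptions: a toy threshold classifier stands in for LightGBM, accuracy stands in for F1, and the subsets are nested slices of a shuffled copy of the training set. The names `sizing_curve`, `fit_threshold` and `make_data` are hypothetical and not part of ATOM's API.

```python
import random
from statistics import mean

# Hypothetical toy model standing in for LightGBM: a 1-D threshold classifier
# that predicts 1 when x lies above the midpoint between the two class means.
def fit_threshold(xs, ys):
    pos = [x for x, y in zip(xs, ys) if y == 1]
    neg = [x for x, y in zip(xs, ys) if y == 0]
    return (mean(pos) + mean(neg)) / 2

def accuracy(threshold, xs, ys):
    preds = [1 if x > threshold else 0 for x in xs]
    return sum(p == y for p, y in zip(preds, ys)) / len(ys)

def sizing_curve(X_train, y_train, X_test, y_test, n_sizes=10, seed=1):
    """Fit on growing fractions of the training set, score on a fixed test set."""
    order = list(range(len(X_train)))
    random.Random(seed).shuffle(order)  # assumption: random subsets, nested by size
    results = []
    for i in range(1, n_sizes + 1):
        frac = i / n_sizes
        subset = order[: int(frac * len(order))]
        threshold = fit_threshold([X_train[j] for j in subset], [y_train[j] for j in subset])
        results.append((frac, len(subset), accuracy(threshold, X_test, y_test)))
    return results

# Synthetic, imbalanced data (about 30% positives), generated for illustration only.
def make_data(n, rng):
    ys = [1 if rng.random() < 0.3 else 0 for _ in range(n)]
    xs = [rng.gauss(1.5 * y, 1.0) for y in ys]
    return xs, ys

rng = random.Random(0)
X_tr, y_tr = make_data(2000, rng)
X_te, y_te = make_data(500, rng)
for frac, n, score in sizing_curve(X_tr, y_tr, X_te, y_te):
    print(f"frac={frac:.1f}  n={n:4d}  accuracy={score:.3f}")
```

Keeping the test set fixed means that any change in score can be attributed to the amount of training data rather than to a different evaluation sample.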
Load the data
In [1]:
# Import packages
import numpy as np
import pandas as pd
from atom import ATOMClassifier
In [2]:
# Load the data
X = pd.read_csv("./datasets/weatherAUS.csv")
# Let's have a look
X.head()
Out[2]:
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
Run the pipeline
In [3]:
# Initialize atom and prepare the data
atom = ATOMClassifier(X, verbose=2, random_state=1)
atom.clean()
atom.impute(strat_num="median", strat_cat="most_frequent", max_nan_rows=0.8)
atom.encode()
<< ================== ATOM ================== >>
Algorithm task: binary classification.

Dataset stats ====================== >>
Shape: (142193, 22)
Scaled: False
Missing values: 316559 (10.1%)
Categorical features: 5 (23.8%)
Duplicate samples: 45 (0.0%)
---------------------------------------
Train set size: 113755
Test set size: 28438
---------------------------------------
| | dataset | train | test |
|---:|:-------------|:------------|:------------|
| 0 | 110316 (3.5) | 88412 (3.5) | 21904 (3.4) |
| 1 | 31877 (1.0) | 25343 (1.0) | 6534 (1.0) |

Applying data cleaning...
Fitting Imputer...
Imputing missing values...
 --> Dropping 15182 samples for containing less than 80% non-missing values.
 --> Imputing 100 missing values with median (12.2) in feature MinTemp.
 --> Imputing 57 missing values with median (22.8) in feature MaxTemp.
 --> Imputing 640 missing values with median (0.0) in feature Rainfall.
 --> Imputing 46535 missing values with median (4.8) in feature Evaporation.
 --> Imputing 53034 missing values with median (8.5) in feature Sunshine.
 --> Imputing 4381 missing values with most_frequent (W) in feature WindGustDir.
 --> Imputing 4359 missing values with median (39.0) in feature WindGustSpeed.
 --> Imputing 6624 missing values with most_frequent (N) in feature WindDir9am.
 --> Imputing 612 missing values with most_frequent (SE) in feature WindDir3pm.
 --> Imputing 80 missing values with median (13.0) in feature WindSpeed9am.
 --> Imputing 49 missing values with median (19.0) in feature WindSpeed3pm.
 --> Imputing 532 missing values with median (69.0) in feature Humidity9am.
 --> Imputing 1168 missing values with median (52.0) in feature Humidity3pm.
 --> Imputing 1028 missing values with median (1017.6) in feature Pressure9am.
 --> Imputing 972 missing values with median (1015.2) in feature Pressure3pm.
 --> Imputing 42172 missing values with median (5.0) in feature Cloud9am.
 --> Imputing 44251 missing values with median (5.0) in feature Cloud3pm.
 --> Imputing 98 missing values with median (16.8) in feature Temp9am.
 --> Imputing 702 missing values with median (21.3) in feature Temp3pm.
 --> Imputing 640 missing values with most_frequent (No) in feature RainToday.
Fitting Encoder...
Encoding categorical columns...
 --> LeaveOneOut-encoding feature Location. Contains 45 classes.
 --> LeaveOneOut-encoding feature WindGustDir. Contains 16 classes.
 --> LeaveOneOut-encoding feature WindDir9am. Contains 16 classes.
 --> LeaveOneOut-encoding feature WindDir3pm. Contains 16 classes.
 --> Ordinal-encoding feature RainToday. Contains 2 classes.
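Two of the steps listed in this log can be illustrated in plain Python: filling missing values with a column's median or most frequent value, and leave-one-out target encoding, where each row's category is replaced by the mean target of the other rows that share that category. This is a from-scratch sketch, not ATOM's implementation: `impute` and `leave_one_out_encode` are hypothetical names, the row-dropping step and the ordinal encoding of RainToday are not covered, and the fallback to the global target mean for single-row categories is an assumption.

```python
from collections import Counter
from statistics import median

def impute(column, strategy):
    """Replace None with the median ("median") or mode ("most_frequent") of the column."""
    present = [v for v in column if v is not None]
    if strategy == "median":
        fill = median(present)
    elif strategy == "most_frequent":
        fill = Counter(present).most_common(1)[0][0]
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return [fill if v is None else v for v in column]

def leave_one_out_encode(categories, target):
    """Encode each row by the mean target of the *other* rows in its category."""
    sums, counts = Counter(), Counter()
    for c, t in zip(categories, target):
        sums[c] += t
        counts[c] += 1
    prior = sum(target) / len(target)
    encoded = []
    for c, t in zip(categories, target):
        if counts[c] > 1:
            # Remove the row's own target before averaging, so a row never sees its label.
            encoded.append((sums[c] - t) / (counts[c] - 1))
        else:
            encoded.append(prior)  # assumption: single-row categories get the global mean
    return encoded

print(impute([18.0, None, 13.6, 16.4], "median"))         # [18.0, 16.4, 13.6, 16.4]
print(impute(["SSE", None, "SSE", "S"], "most_frequent"))  # ['SSE', 'SSE', 'SSE', 'S']
print(leave_one_out_encode(["Cairns", "Cairns", "Adelaide", "Cairns"], [1, 0, 1, 1]))
# [0.5, 1.0, 0.75, 0.5]
```

Excluding the row's own label is what distinguishes leave-one-out encoding from plain target (mean) encoding, which would leak each row's target into its own feature value.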
In [4]:
# Analyze the impact of the training set's size on a LightGBM model
atom.train_sizing("LGB", train_sizes=10, n_bootstrap=5)
Training ===================================== >>
Models: LGB
Metric: f1

Run: 0 ================================ >>
Size of training set: 10165 (10%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.8093
Test evaluation --> f1: 0.61
Time elapsed: 0.749s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6092 ± 0.0055
Time elapsed: 1.927s
-------------------------------------------------
Total time: 2.675s

Final results ========================= >>
Duration: 2.675s
------------------------------------------
LightGBM --> f1: 0.6092 ± 0.0055 ~

Run: 1 ================================ >>
Size of training set: 20330 (20%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.7328
Test evaluation --> f1: 0.6218
Time elapsed: 0.901s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.615 ± 0.0027
Time elapsed: 2.385s
-------------------------------------------------
Total time: 3.287s

Final results ========================= >>
Duration: 3.288s
------------------------------------------
LightGBM --> f1: 0.615 ± 0.0027

Run: 2 ================================ >>
Size of training set: 30495 (30%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.7075
Test evaluation --> f1: 0.6252
Time elapsed: 1.076s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6199 ± 0.0019
Time elapsed: 2.927s
-------------------------------------------------
Total time: 4.004s

Final results ========================= >>
Duration: 4.004s
------------------------------------------
LightGBM --> f1: 0.6199 ± 0.0019

Run: 3 ================================ >>
Size of training set: 40660 (40%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6939
Test evaluation --> f1: 0.6275
Time elapsed: 1.369s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6215 ± 0.002
Time elapsed: 3.555s
-------------------------------------------------
Total time: 4.924s

Final results ========================= >>
Duration: 4.925s
------------------------------------------
LightGBM --> f1: 0.6215 ± 0.002

Run: 4 ================================ >>
Size of training set: 50826 (50%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6814
Test evaluation --> f1: 0.6291
Time elapsed: 1.590s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.623 ± 0.0014
Time elapsed: 4.344s
-------------------------------------------------
Total time: 5.935s

Final results ========================= >>
Duration: 5.935s
------------------------------------------
LightGBM --> f1: 0.623 ± 0.0014

Run: 5 ================================ >>
Size of training set: 60991 (60%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6766
Test evaluation --> f1: 0.6356
Time elapsed: 1.813s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6283 ± 0.0022
Time elapsed: 4.853s
-------------------------------------------------
Total time: 6.667s

Final results ========================= >>
Duration: 6.667s
------------------------------------------
LightGBM --> f1: 0.6283 ± 0.0022

Run: 6 ================================ >>
Size of training set: 71156 (70%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6742
Test evaluation --> f1: 0.6289
Time elapsed: 2.045s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6297 ± 0.0024
Time elapsed: 5.420s
-------------------------------------------------
Total time: 7.466s

Final results ========================= >>
Duration: 7.466s
------------------------------------------
LightGBM --> f1: 0.6297 ± 0.0024

Run: 7 ================================ >>
Size of training set: 81321 (80%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.672
Test evaluation --> f1: 0.6322
Time elapsed: 2.165s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6311 ± 0.0023
Time elapsed: 5.971s
-------------------------------------------------
Total time: 8.138s

Final results ========================= >>
Duration: 8.138s
------------------------------------------
LightGBM --> f1: 0.6311 ± 0.0023

Run: 8 ================================ >>
Size of training set: 91486 (90%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6674
Test evaluation --> f1: 0.6354
Time elapsed: 2.503s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6321 ± 0.0039
Time elapsed: 6.606s
-------------------------------------------------
Total time: 9.111s

Final results ========================= >>
Duration: 9.111s
------------------------------------------
LightGBM --> f1: 0.6321 ± 0.0039

Run: 9 ================================ >>
Size of training set: 101652 (100%)
Size of test set: 25359

Results for LightGBM:
Fit ---------------------------------------------
Train evaluation --> f1: 0.665
Test evaluation --> f1: 0.6356
Time elapsed: 2.612s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.6337 ± 0.0038
Time elapsed: 7.193s
-------------------------------------------------
Total time: 9.806s

Final results ========================= >>
Duration: 9.806s
------------------------------------------
LightGBM --> f1: 0.6337 ± 0.0038
Analyze the results
In [5]:
# The results are now a multi-index DataFrame, where frac is the fraction
# of the training set used to fit the model. The model names
# also end with that fraction (without the dot)
atom.results
Out[5]:
frac | model | metric_train | metric_test | time_fit | mean_bootstrap | std_bootstrap | time_bootstrap | time |
---|---|---|---|---|---|---|---|---|
0.1 | LGB01 | 0.809309 | 0.610019 | 0.749s | 0.609161 | 0.005489 | 1.927s | 2.675s |
0.2 | LGB02 | 0.732799 | 0.621756 | 0.901s | 0.614965 | 0.002726 | 2.385s | 3.287s |
0.3 | LGB03 | 0.707509 | 0.625230 | 1.076s | 0.619943 | 0.001864 | 2.927s | 4.004s |
0.4 | LGB04 | 0.693875 | 0.627527 | 1.369s | 0.621488 | 0.001970 | 3.555s | 4.924s |
0.5 | LGB05 | 0.681449 | 0.629054 | 1.590s | 0.623023 | 0.001408 | 4.344s | 5.935s |
0.6 | LGB06 | 0.676566 | 0.635607 | 1.813s | 0.628340 | 0.002204 | 4.853s | 6.667s |
0.7 | LGB07 | 0.674170 | 0.628914 | 2.045s | 0.629710 | 0.002405 | 5.420s | 7.466s |
0.8 | LGB08 | 0.672033 | 0.632210 | 2.165s | 0.631072 | 0.002290 | 5.971s | 8.138s |
0.9 | LGB09 | 0.667443 | 0.635430 | 2.503s | 0.632102 | 0.003885 | 6.606s | 9.111s |
1.0 | LGB10 | 0.665031 | 0.635616 | 2.612s | 0.633657 | 0.003810 | 7.193s | 9.806s |
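The `frac` index, the model names and the training-set sizes printed in the log follow a simple pattern. The sketch reproduces them for the 101652 training rows that remain after imputation: the name appends the fraction without its dot, and the size is that fraction of the training set truncated to an integer. `run_labels` is a hypothetical helper that mirrors the output of this example, not ATOM's code; it only covers the ten evenly spaced fractions used here, and other `train_sizes` values may be named differently.

```python
def run_labels(n_train, model="LGB", n_sizes=10):
    """Return (frac, model name, training-set size) for evenly spaced fractions."""
    rows = []
    for i in range(1, n_sizes + 1):
        frac = i / n_sizes
        name = model + str(frac).replace(".", "")  # 0.1 -> "LGB01", 1.0 -> "LGB10"
        size = int(frac * n_train)                  # truncate, e.g. 0.1 * 101652 -> 10165
        rows.append((frac, name, size))
    return rows

for frac, name, size in run_labels(101652):
    print(f"{frac:.1f}  {name}  {size}")
```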
In [6]:
# Every model can be accessed through its name
atom.lgb05.waterfall_plot(show=7)
In [7]:
# Plot the train sizing's results
atom.plot_learning_curve()
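The trend of the learning curve can also be read numerically from the results table. The sketch copies the `metric_train`, `metric_test` and `mean_bootstrap` columns and computes the gain from each additional 10% of training data and the gap between train and test F1. With these numbers, every step raises the mean bootstrap F1 slightly (by roughly 0.001 to 0.006), while the train-test gap shrinks from about 0.20 at 10% to about 0.03 at 100%, suggesting that the model overfits less as it sees more data.

```python
# Columns copied from the results table (frac = 0.1, 0.2, ..., 1.0).
metric_train = [0.809309, 0.732799, 0.707509, 0.693875, 0.681449,
                0.676566, 0.674170, 0.672033, 0.667443, 0.665031]
metric_test = [0.610019, 0.621756, 0.625230, 0.627527, 0.629054,
               0.635607, 0.628914, 0.632210, 0.635430, 0.635616]
mean_bootstrap = [0.609161, 0.614965, 0.619943, 0.621488, 0.623023,
                  0.628340, 0.629710, 0.631072, 0.632102, 0.633657]

# Gain in mean bootstrap F1 from each additional 10% of the training set.
gains = [round(b - a, 6) for a, b in zip(mean_bootstrap, mean_bootstrap[1:])]
# Gap between train and test F1; a shrinking gap means less overfitting.
gaps = [round(tr - te, 6) for tr, te in zip(metric_train, metric_test)]

print("gain per extra 10%:", gains)
print("train-test gap:    ", gaps)
```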