Example: Holdout set
This example shows when and how to use ATOM's holdout set in an exploration pipeline.
The data used is a variation on the Australian weather dataset from Kaggle. You can download it from here. The goal is to predict whether or not it will rain tomorrow by training a binary classifier on the target column RainTomorrow.
Load the data
# Import packages
import pandas as pd
from atom import ATOMClassifier
# Load data
X = pd.read_csv("./datasets/weatherAUS.csv")
# Let's have a look
X.head()
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
Run the pipeline
# Initialize atom, subsampling half of the rows (n_rows=0.5) and setting aside 20% of them as holdout set
atom = ATOMClassifier(X, n_rows=0.5, holdout_size=0.2, verbose=2)
<< ================== ATOM ================== >>
Algorithm task: binary classification.

Dataset stats ==================== >>
Shape: (56877, 22)
Train set size: 42658
Test set size: 14219
Holdout set size: 14219
-------------------------------------
Memory: 24.68 MB
Scaled: False
Missing values: 126457 (10.1%)
Categorical features: 5 (23.8%)
Duplicate samples: 15 (0.0%)
# The test and holdout fractions are split after subsampling the dataset
# Also note that the holdout data set is not a part of atom's dataset
print("Length loaded data:", len(X))
print("Length dataset + holdout:", len(atom.dataset) + len(atom.holdout))
Length loaded data: 142193
Length dataset + holdout: 71096
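The printed sizes can be reproduced with a little arithmetic. The sketch below assumes that n_rows, holdout_size and the test fraction (0.2 here, which is an assumption about atom's default but matches the printed test size) are all taken relative to the subsampled data and rounded down. It is a back-of-the-envelope check, not ATOM's actual splitting code; split_sizes is a hypothetical helper.

```python
import math


def split_sizes(n_loaded, n_rows=0.5, test_size=0.2, holdout_size=0.2):
    """Hypothetical helper: row counts per split, assuming fractions are rounded down."""
    n_sampled = math.floor(n_loaded * n_rows)         # subsample the loaded data first
    n_holdout = math.floor(n_sampled * holdout_size)  # holdout fraction of the subsample
    n_test = math.floor(n_sampled * test_size)        # test fraction of the subsample
    n_train = n_sampled - n_holdout - n_test          # the remainder is used for training
    return {"sampled": n_sampled, "train": n_train, "test": n_test, "holdout": n_holdout}


sizes = split_sizes(142193)
print(sizes)
```

Under these assumptions, train + test gives the 56877 rows reported as the dataset shape, and adding the holdout set gives the 71096 rows printed above.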
atom.impute()
atom.encode()
Fitting Imputer...
Imputing missing values...
 --> Dropping 246 samples due to missing values in feature MinTemp.
 --> Dropping 136 samples due to missing values in feature MaxTemp.
 --> Dropping 544 samples due to missing values in feature Rainfall.
 --> Dropping 24464 samples due to missing values in feature Evaporation.
 --> Dropping 27155 samples due to missing values in feature Sunshine.
 --> Dropping 3642 samples due to missing values in feature WindGustDir.
 --> Dropping 3616 samples due to missing values in feature WindGustSpeed.
 --> Dropping 4027 samples due to missing values in feature WindDir9am.
 --> Dropping 1491 samples due to missing values in feature WindDir3pm.
 --> Dropping 536 samples due to missing values in feature WindSpeed9am.
 --> Dropping 1035 samples due to missing values in feature WindSpeed3pm.
 --> Dropping 690 samples due to missing values in feature Humidity9am.
 --> Dropping 1419 samples due to missing values in feature Humidity3pm.
 --> Dropping 5566 samples due to missing values in feature Pressure9am.
 --> Dropping 5548 samples due to missing values in feature Pressure3pm.
 --> Dropping 21527 samples due to missing values in feature Cloud9am.
 --> Dropping 22832 samples due to missing values in feature Cloud3pm.
 --> Dropping 373 samples due to missing values in feature Temp9am.
 --> Dropping 1066 samples due to missing values in feature Temp3pm.
 --> Dropping 544 samples due to missing values in feature RainToday.
Fitting Encoder...
Encoding categorical columns...
 --> LeaveOneOut-encoding feature Location. Contains 26 classes.
 --> LeaveOneOut-encoding feature WindGustDir. Contains 16 classes.
 --> LeaveOneOut-encoding feature WindDir9am. Contains 16 classes.
 --> LeaveOneOut-encoding feature WindDir3pm. Contains 16 classes.
 --> Ordinal-encoding feature RainToday. Contains 2 classes.
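The log shows that the high-cardinality columns are LeaveOneOut-encoded. The sketch below illustrates the basic idea of that technique in plain Python: each training row gets the mean target of the other rows in its category, so its own label never leaks into its encoding. It is a simplified illustration (no smoothing or noise), not the implementation ATOM uses; the function name and toy data are made up.

```python
from collections import defaultdict


def leave_one_out_encode(categories, target):
    """Hypothetical helper: encode each training row with the mean target of the *other* rows in its category."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for cat, y in zip(categories, target):
        sums[cat] += y
        counts[cat] += 1
    global_mean = sum(target) / len(target)

    encoded = []
    for cat, y in zip(categories, target):
        if counts[cat] > 1:
            # Remove the row's own target from its category statistics.
            encoded.append((sums[cat] - y) / (counts[cat] - 1))
        else:
            # A category seen only once has no "other" rows; fall back to the global mean.
            encoded.append(global_mean)

    # Unseen rows (e.g. test or holdout data) would use the full category mean instead.
    category_means = {cat: sums[cat] / counts[cat] for cat in sums}
    return encoded, category_means, global_mean


locations = ["Sydney", "Sydney", "Perth", "Sydney", "Perth"]
rain_tomorrow = [1, 0, 1, 1, 0]
encoded, means, prior = leave_one_out_encode(locations, rain_tomorrow)
print(encoded)  # [0.5, 1.0, 0.0, 0.5, 1.0]
```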
# Unlike train and test, the holdout data set is not transformed until used for predictions
atom.holdout
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | PearceRAAF | 13.3 | 25.6 | 0.0 | NaN | 5.8 | WNW | 52.0 | WNW | WNW | ... | 58.0 | 51.0 | 1010.8 | 1009.6 | 7.0 | 7.0 | 21.6 | 24.8 | No | 1 |
1 | Darwin | 25.2 | 30.2 | 4.2 | 3.4 | 2.1 | WNW | 76.0 | WNW | WNW | ... | 83.0 | 85.0 | 1005.9 | 1003.4 | 7.0 | 7.0 | 28.6 | 27.4 | Yes | 1 |
2 | Portland | 8.1 | 15.1 | 4.4 | 3.8 | 7.5 | W | 52.0 | SW | SSW | ... | 66.0 | 60.0 | 1013.6 | 1017.4 | 8.0 | 7.0 | 13.4 | 13.2 | Yes | 0 |
3 | Perth | 10.5 | 22.7 | 0.0 | 2.4 | 9.2 | WNW | 26.0 | NNE | NW | ... | 86.0 | 68.0 | 1016.2 | 1014.3 | 1.0 | 3.0 | 15.8 | 20.8 | No | 1 |
4 | MountGinini | 14.7 | 24.1 | 102.2 | NaN | NaN | NW | 52.0 | SW | W | ... | 100.0 | 78.0 | NaN | NaN | NaN | NaN | 16.5 | 21.0 | Yes | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
14214 | Walpole | 10.3 | 16.1 | 2.4 | NaN | NaN | NW | 20.0 | NW | SSE | ... | 97.0 | 78.0 | 1024.4 | 1022.9 | NaN | NaN | 12.3 | 14.7 | Yes | 0 |
14215 | AliceSprings | 24.8 | 29.0 | 0.2 | 9.4 | 0.0 | NNW | 31.0 | NNE | SE | ... | 69.0 | 56.0 | 1009.4 | 1007.1 | 8.0 | 8.0 | 26.1 | 27.1 | No | 0 |
14216 | Darwin | 26.2 | 34.4 | 0.0 | 4.4 | 9.0 | NW | 33.0 | E | WNW | ... | 69.0 | 65.0 | 1009.5 | 1006.0 | 3.0 | 6.0 | 30.4 | 32.0 | No | 0 |
14217 | Cairns | 23.9 | 28.4 | 2.6 | 2.6 | 0.6 | ENE | 28.0 | NE | NE | ... | 72.0 | 78.0 | 1014.4 | 1011.8 | 7.0 | 8.0 | 27.5 | 27.0 | Yes | 1 |
14218 | Tuggeranong | 9.5 | 23.6 | 0.0 | NaN | NaN | W | 30.0 | NaN | W | ... | 68.0 | 40.0 | 1010.9 | 1008.4 | NaN | NaN | 14.9 | 22.6 | No | 1 |
14219 rows × 22 columns
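Keeping the holdout set raw means every transformer is fitted on the training data alone and is only applied to the holdout rows when they are needed for predictions. The sketch below shows that pattern with a hypothetical mean imputer in plain Python. Note that this run of atom actually drops rows with missing values, so mean imputation is a swapped-in example, and the numbers are toy values loosely borrowed from the tables above.

```python
class MeanImputer:
    """Hypothetical toy imputer: learns column means on training rows only."""

    def fit(self, rows):
        n_cols = len(rows[0])
        self.means_ = []
        for j in range(n_cols):
            observed = [row[j] for row in rows if row[j] is not None]
            self.means_.append(sum(observed) / len(observed))
        return self

    def transform(self, rows):
        # Returns new rows; the input rows are left untouched.
        return [
            [self.means_[j] if value is None else value for j, value in enumerate(row)]
            for row in rows
        ]


train = [[18.0, 26.9], [None, 23.4], [13.6, None]]
holdout_raw = [[None, 25.6], [25.2, None]]

imputer = MeanImputer().fit(train)              # statistics come from the training rows only
holdout_ready = imputer.transform(holdout_raw)  # applied only when the holdout is used
print(holdout_raw)    # still contains None: the stored holdout stays raw
print(holdout_ready)  # missing values filled with the training means
```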
atom.run(models=["GNB", "LR", "RF"])
Training ========================= >>
Models: GNB, LR, RF
Metric: f1

Results for GaussianNB:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6073
Test evaluation --> f1: 0.6164
Time elapsed: 0.044s
-------------------------------------------------
Total time: 0.044s

Results for LogisticRegression:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6261
Test evaluation --> f1: 0.6333
Time elapsed: 0.134s
-------------------------------------------------
Total time: 0.134s

Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> f1: 0.9999
Test evaluation --> f1: 0.6163
Time elapsed: 3.344s
-------------------------------------------------
Total time: 3.344s

Final results ==================== >>
Total time: 3.527s
-------------------------------------
GaussianNB --> f1: 0.6164
LogisticRegression --> f1: 0.6333 !
RandomForest --> f1: 0.6163 ~
atom.plot_prc()
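plot_prc draws each model's precision-recall curve. The sketch below shows, in plain Python on made-up scores, how such a curve is traced by ranking rows by predicted probability and sweeping the threshold, together with the step-wise average precision AP = Σ (Rₙ − Rₙ₋₁) Pₙ, the definition also used by scikit-learn's average_precision_score. It is an illustration only, not how ATOM draws the plot; ties in the scores are deliberately not handled.

```python
def precision_recall_points(y_true, scores):
    """(precision, recall) at every cut-off when rows are ranked by descending score.

    Simplification: tied scores are not grouped together.
    """
    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(y_true)
    tp = fp = 0
    points = []
    for _, label in ranked:
        if label == 1:
            tp += 1
        else:
            fp += 1
        points.append((tp / (tp + fp), tp / n_pos))
    return points


def average_precision(points):
    """Sum of precision weighted by the increase in recall at each step."""
    ap, prev_recall = 0.0, 0.0
    for precision, recall in points:
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


y_true = [1, 0, 1, 1, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.2]
points = precision_recall_points(y_true, scores)
print(points)
print(round(average_precision(points), 4))  # 0.8056
```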
# Based on the results on the test set, we select the best model for further tuning
atom.run("lr_tuned", n_trials=10)
Training ========================= >>
Models: LR_tuned
Metric: f1

Running hyperparameter tuning for LogisticRegression...
| trial | penalty | C | solver | max_iter | l1_ratio | f1 | best_f1 | time_trial | time_ht | state |
| ----- | ------- | ------- | ------- | -------- | -------- | ------- | ------- | ---------- | ------- | -------- |
| 0 | None | --- | newto.. | 990 | --- | 0.6325 | 0.6325 | 0.623s | 0.623s | COMPLETE |
| 1 | l2 | 0.0096 | sag | 330 | --- | 0.6076 | 0.6325 | 0.582s | 1.205s | COMPLETE |
| 2 | None | --- | newto.. | 490 | --- | 0.6099 | 0.6325 | 0.567s | 1.772s | COMPLETE |
| 3 | l2 | 0.0023 | newto.. | 940 | --- | 0.5731 | 0.6325 | 0.531s | 2.302s | COMPLETE |
| 4 | l2 | 17.8193 | lbfgs | 860 | --- | 0.6297 | 0.6325 | 0.504s | 2.806s | COMPLETE |
| 5 | l2 | 20.6114 | libli.. | 650 | --- | 0.6346 | 0.6346 | 0.516s | 3.323s | COMPLETE |
| 6 | l2 | 74.6411 | sag | 810 | --- | 0.6269 | 0.6346 | 0.592s | 3.914s | COMPLETE |
| 7 | l2 | 36.7273 | lbfgs | 290 | --- | 0.6361 | 0.6361 | 0.546s | 4.461s | COMPLETE |
| 8 | None | --- | newto.. | 480 | --- | 0.6328 | 0.6361 | 0.582s | 5.042s | COMPLETE |
| 9 | l2 | 0.02 | lbfgs | 760 | --- | 0.6067 | 0.6361 | 0.506s | 5.549s | COMPLETE |
Hyperparameter tuning ---------------------------
Best trial --> 7
Best parameters:
 --> penalty: l2
 --> C: 36.7273
 --> solver: lbfgs
 --> max_iter: 290
Best evaluation --> f1: 0.6361
Time elapsed: 5.549s
Fit ---------------------------------------------
Train evaluation --> f1: 0.6264
Test evaluation --> f1: 0.6343
Time elapsed: 0.137s
-------------------------------------------------
Total time: 5.686s

Final results ==================== >>
Total time: 5.713s
-------------------------------------
LogisticRegression --> f1: 0.6343
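The best_f1 column in the tuning log is the running maximum of the f1 column, and the best trial is the one that attains it. The short check below recomputes both from the values printed above; it only re-reads the log and makes no assumption about how ATOM samples the hyperparameters.

```python
from itertools import accumulate

# f1 score of each of the 10 trials, copied from the tuning log
trial_f1 = [0.6325, 0.6076, 0.6099, 0.5731, 0.6297, 0.6346, 0.6269, 0.6361, 0.6328, 0.6067]

# best_f1: highest score seen up to and including each trial
best_so_far = list(accumulate(trial_f1, max))

# index of the trial with the highest f1
best_trial = max(range(len(trial_f1)), key=trial_f1.__getitem__)

print(best_so_far)
print("Best trial -->", best_trial)  # Best trial --> 7
```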
Analyze the results
We already used the test set to choose the best model for further tuning, so this set is no longer truly independent. Even if the effect is not directly visible in the results, evaluating the tuned LR model on the test set would now be a mistake: the model was chosen precisely because it scored well on that set, so its test score is optimistically biased. For this reason, we set apart an extra, independent set to validate the final model: the holdout set. Since we no longer use the test set for validation, we might as well use it to train the model and so make better use of the available data. Use the full_train method for this.
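The bias described above is a selection effect: when several models are compared on the same test set, part of the winner's advantage is luck. The simulation below illustrates this in plain Python under deliberately simple, made-up assumptions (three candidate models with the same true accuracy of 0.60, scores modelled as the fraction of correct predictions on a sample). It is not part of ATOM, and simulate_selection_bias is a hypothetical helper.

```python
import random
import statistics


def simulate_selection_bias(n_models=3, true_score=0.60, n_test=400, n_holdout=400, n_runs=1000, seed=0):
    """Pick the best of several equally good models on a test set, then re-score it on fresh data."""
    rng = random.Random(seed)

    def observed_score(n_samples):
        # Fraction of correct predictions for a model whose true accuracy is true_score.
        return sum(rng.random() < true_score for _ in range(n_samples)) / n_samples

    test_of_winner, holdout_of_winner = [], []
    for _ in range(n_runs):
        test_scores = [observed_score(n_test) for _ in range(n_models)]
        test_of_winner.append(max(test_scores))             # score used to select the model
        holdout_of_winner.append(observed_score(n_holdout))  # same model, independent data
    return statistics.mean(test_of_winner), statistics.mean(holdout_of_winner)


test_mean, holdout_mean = simulate_selection_bias()
print(f"mean test score of the selected model:    {test_mean:.3f}")
print(f"mean holdout score of the selected model: {holdout_mean:.3f}")
```

Because all candidates are equally good, the gap between the two means is pure selection bias: the holdout score stays close to the true 0.60, while the winner's test score is inflated.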
# Re-train the model on the full dataset (train + test)
atom.lr_tuned.full_train()
Fit ---------------------------------------------
Train evaluation --> f1: 0.626
Test evaluation --> f1: 0.6355
Time elapsed: 0.270s
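After full_train, the test rows are part of the training data, so the test evaluation printed above is no longer an out-of-sample estimate; only the holdout set still is. The toy example below illustrates why, using a hypothetical lookup "model" (majority label per location) in plain Python. The data are invented and the model is far simpler than logistic regression.

```python
from collections import Counter, defaultdict


def fit_lookup(rows):
    """Hypothetical toy model: majority label for each key."""
    votes = defaultdict(Counter)
    for key, label in rows:
        votes[key][label] += 1
    return {key: counter.most_common(1)[0][0] for key, counter in votes.items()}


def accuracy(model, rows, default=0):
    # Unseen keys fall back to a default prediction.
    return sum(model.get(key, default) == label for key, label in rows) / len(rows)


train = [("Sydney", 1), ("Sydney", 1), ("Perth", 0), ("Perth", 0)]
test = [("Darwin", 1), ("Sydney", 1), ("Perth", 1)]

before = accuracy(fit_lookup(train), test)        # test rows unseen during fitting
after = accuracy(fit_lookup(train + test), test)  # test rows now part of the training data
print(before, after)
```

The score on the test rows goes up once they are used for fitting, even though the model has not become any better at predicting new data.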
# Evaluate on the holdout set
atom.lr_tuned.evaluate(dataset="holdout")
accuracy 0.8550
average_precision 0.7310
balanced_accuracy 0.7396
f1 0.6185
jaccard 0.4477
matthews_corrcoef 0.5422
precision 0.7358
recall 0.5334
roc_auc 0.8918
Name: LR_tuned, dtype: float64
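The reported holdout metrics are consistent with each other, which makes a quick sanity check. The snippet below recomputes F1 from the printed precision and recall, and the Jaccard score from the printed F1, using the standard identities for binary classification; small differences come from the four-decimal rounding of the inputs.

```python
# Values copied from the holdout evaluation
precision, recall, f1_reported = 0.7358, 0.5334, 0.6185

# F1 is the harmonic mean of precision and recall
f1 = 2 * precision * recall / (precision + recall)

# Binary case: Jaccard = TP / (TP + FP + FN) = F1 / (2 - F1)
jaccard = f1_reported / (2 - f1_reported)

print(round(f1, 4), round(jaccard, 4))  # 0.6185 0.4477
```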
atom.lr_tuned.plot_prc(dataset="holdout", legend="upper right")