Binary classification
This example shows how to use ATOM to solve a binary classification problem. Additionally, we'll perform a variety of data cleaning steps to prepare the data for modelling.
The data used is a variation on the Australian weather dataset from Kaggle. You can download it from here. The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target column RainTomorrow.
Load the data
In [1]:
# Import packages
import pandas as pd
from atom import ATOMClassifier
In [2]:
# Load data
X = pd.read_csv("./datasets/weatherAUS.csv")
# Let's have a look
X.head()
Out[2]:
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
Run the pipeline
In [3]:
# Call atom using only 5% of the complete dataset (for explanatory purposes)
atom = ATOMClassifier(X, "RainTomorrow", n_rows=0.05, n_jobs=8, warnings=False, verbose=2)
<< ================== ATOM ================== >>
Algorithm task: binary classification.
Parallel processing with 8 cores.

Dataset stats ==================== >>
Shape: (7109, 22)
Memory: 3.08 MB
Scaled: False
Missing values: 15896 (10.2%)
Categorical features: 5 (23.8%)
Duplicate samples: 2 (0.0%)
-------------------------------------
Train set size: 5688
Test set size: 1421
-------------------------------------
|   | dataset    | train      | test       |
| - | ---------- | ---------- | ---------- |
| 0 | 5614 (3.8) | 4492 (3.8) | 1122 (3.8) |
| 1 | 1495 (1.0) | 1196 (1.0) | 299 (1.0)  |
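The dataset stats can be checked by hand from the printed counts. The plain-Python sketch below (no ATOM dependency) recomputes them. Reading the value in parentheses as "class count divided by the minority-class count" is our inference from the numbers, not something taken from ATOM's documentation. Either way, it makes the roughly 3.8:1 imbalance explicit: a model that always predicts 0 would already reach about 79% accuracy (5614 / 7109), which is why a metric like f1 is more informative here.

```python
# Sketch: reproduce the summary statistics printed by ATOMClassifier above.
# Assumption (inferred from the numbers, not from ATOM's docs): the value in
# parentheses is each class count divided by the minority-class count.

def class_ratios(counts):
    """Divide each class count by the smallest count, rounded to 1 decimal."""
    minority = min(counts.values())
    return {label: round(n / minority, 1) for label, n in counts.items()}

def pct(part, total):
    """Format part / total as a percentage with one decimal."""
    return f"{part / total:.1%}"

# Class counts copied from the ATOM output
splits = {
    "dataset": {0: 5614, 1: 1495},
    "train": {0: 4492, 1: 1196},
    "test": {0: 1122, 1: 299},
}
for name, counts in splits.items():
    print(name, class_ratios(counts))

n_rows, n_cols = 7109, 22
# Missing cells over all cells of the (7109, 22) frame
print("Missing values:", pct(15896, n_rows * n_cols))
# 5 categorical columns out of 21 features (22 columns minus the target)
print("Categorical features:", pct(5, n_cols - 1))
# Share of rows held out for testing
print("Test share:", pct(1421, n_rows))
```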
In [4]:
# Impute missing values
atom.impute(strat_num="median", strat_cat="drop", max_nan_rows=0.8)
Fitting Imputer...
Imputing missing values...
--> Dropping 774 samples for containing more than 16 missing values.
--> Imputing 7 missing values with median (12.1) in feature MinTemp.
--> Imputing 5 missing values with median (22.9) in feature MaxTemp.
--> Imputing 33 missing values with median (0.0) in feature Rainfall.
--> Imputing 2315 missing values with median (4.8) in feature Evaporation.
--> Imputing 2648 missing values with median (8.5) in feature Sunshine.
--> Dropping 202 samples due to missing values in feature WindGustDir.
--> Imputing 200 missing values with median (39.0) in feature WindGustSpeed.
--> Dropping 365 samples due to missing values in feature WindDir9am.
--> Dropping 24 samples due to missing values in feature WindDir3pm.
--> Imputing 4 missing values with median (13.0) in feature WindSpeed9am.
--> Imputing 3 missing values with median (19.0) in feature WindSpeed3pm.
--> Imputing 23 missing values with median (69.0) in feature Humidity9am.
--> Imputing 57 missing values with median (52.0) in feature Humidity3pm.
--> Imputing 42 missing values with median (1017.6) in feature Pressure9am.
--> Imputing 40 missing values with median (1015.2) in feature Pressure3pm.
--> Imputing 2112 missing values with median (5.0) in feature Cloud9am.
--> Imputing 2200 missing values with median (5.0) in feature Cloud3pm.
--> Imputing 5 missing values with median (16.9) in feature Temp9am.
--> Imputing 34 missing values with median (21.3) in feature Temp3pm.
--> Dropping 33 samples due to missing values in feature RainToday.
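The log follows three rules: drop rows with too many gaps, drop rows with a missing categorical value (strat_cat="drop"), and fill numeric gaps with the median (strat_num="median"). Below is a minimal plain-Python sketch of those rules, assuming rows are dicts with None for missing values; `impute` is a hypothetical helper, not ATOM's implementation. The row-drop threshold is read here as a fraction of the feature count: with 21 features, 0.8 × 21 = 16.8, which is consistent with the "more than 16 missing values" message above.

```python
from statistics import median

def impute(rows, num_cols, cat_cols, max_nan_rows=0.8):
    """Hypothetical helper mimicking strat_num="median", strat_cat="drop"."""
    features = num_cols + cat_cols
    limit = max_nan_rows * len(features)  # e.g. 0.8 * 21 = 16.8

    # 1) Drop rows with more missing values than the limit allows
    rows = [r for r in rows if sum(r[c] is None for c in features) <= limit]

    # 2) strat_cat="drop": drop rows missing any categorical value
    rows = [r for r in rows if all(r[c] is not None for c in cat_cols)]

    # 3) strat_num="median": fill numeric gaps with the column median.
    #    A real pipeline should learn these medians on the training split only.
    medians = {c: median(r[c] for r in rows if r[c] is not None) for c in num_cols}
    return [{**r, **{c: medians[c] for c in num_cols if r[c] is None}} for r in rows]

# Toy rows (hypothetical values), None marks a missing entry
rows = [
    {"MinTemp": 18.0, "Sunshine": None, "WindGustDir": "SSE"},
    {"MinTemp": 17.2, "Sunshine": 8.9, "WindGustDir": "S"},
    {"MinTemp": None, "Sunshine": 6.1, "WindGustDir": "SE"},
    {"MinTemp": 13.6, "Sunshine": 0.0, "WindGustDir": None},  # dropped: categorical gap
    {"MinTemp": None, "Sunshine": None, "WindGustDir": None},  # dropped: 3 > 0.8 * 3
]
for row in impute(rows, ["MinTemp", "Sunshine"], ["WindGustDir"]):
    print(row)
```

The order matters: medians are computed only after the drops, so rows removed in steps 1 and 2 do not influence the fill values.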
In [5]:
# Encode the categorical features
atom.encode(strategy="Target", max_onehot=10, frac_to_other=0.04)
Fitting Encoder...
Encoding categorical columns...
--> Target-encoding feature Location. Contains 44 classes.
--> Target-encoding feature WindGustDir. Contains 16 classes.
--> Target-encoding feature WindDir9am. Contains 16 classes.
--> Target-encoding feature WindDir3pm. Contains 16 classes.
--> Ordinal-encoding feature RainToday. Contains 2 classes.
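In the log, RainToday (2 classes) is encoded ordinally, while Location (44 classes) and the three wind-direction features (16 classes each) exceed max_onehot=10 and fall through to the chosen "Target" strategy. The sketch below illustrates that decision and a plain target encoding. Both helpers are hypothetical: the rule in `choose_strategy` is inferred from the log plus our reading of max_onehot as an inclusive upper bound, and `target_encode` uses an unsmoothed per-category mean, whereas ATOM's Target strategy may apply smoothing. frac_to_other is not modelled.

```python
from collections import defaultdict

def choose_strategy(n_classes, max_onehot=10, strategy="Target"):
    """Hypothetical rule inferred from the log: 2 -> ordinal, few -> one-hot."""
    if n_classes == 2:
        return "Ordinal"
    if n_classes <= max_onehot:
        return "OneHot"
    return strategy

def target_encode(values, target):
    """Replace each category by the mean of the binary target in that category.

    No smoothing; in practice the means must be learned on the training
    split only, otherwise the target leaks into the features.
    """
    sums, counts = defaultdict(float), defaultdict(int)
    for v, y in zip(values, target):
        sums[v] += y
        counts[v] += 1
    means = {v: sums[v] / counts[v] for v in counts}
    return [means[v] for v in values]

print(choose_strategy(44), choose_strategy(16), choose_strategy(2))
# Toy wind directions (hypothetical): S -> 2/3, SE -> 0.0
print(target_encode(["S", "SE", "S", "SE", "S"], [1, 0, 0, 0, 1]))
```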
In [6]:
# Train an Extra-Trees and a Random Forest model
atom.run(models=["ET", "RF"], metric="f1", n_bootstrap=5)
Training ========================= >>
Models: ET, RF
Metric: f1

Results for Extra-Trees:
Fit ---------------------------------------------
Train evaluation --> f1: 1.0
Test evaluation --> f1: 0.5823
Time elapsed: 0.194s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.5593 ± 0.0155
Time elapsed: 0.794s
-------------------------------------------------
Total time: 0.991s

Results for Random Forest:
Fit ---------------------------------------------
Train evaluation --> f1: 1.0
Test evaluation --> f1: 0.5985
Time elapsed: 0.244s
Bootstrap ---------------------------------------
Evaluation --> f1: 0.592 ± 0.0231
Time elapsed: 1.012s
-------------------------------------------------
Total time: 1.257s

Final results ==================== >>
Duration: 2.249s
-------------------------------------
Extra-Trees --> f1: 0.5593 ± 0.0155 ~
Random Forest --> f1: 0.592 ± 0.0231 ~ !
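The "± std" figures come from n_bootstrap=5. Conceptually, bootstrapping refits the model on training sets resampled with replacement and scores every refit on the same test set. The sketch below shows that loop in plain Python; it mirrors ATOM's procedure only conceptually, `fit_threshold` is a hypothetical one-feature toy model standing in for a real estimator, and the use of the population standard deviation is an assumption. Note also that both models score f1 1.0 on the training set versus about 0.58 to 0.60 on the test set; the `~` in the summary appears to flag exactly this kind of possible overfitting.

```python
import random
from statistics import mean, pstdev

def f1_score(y_true, y_pred):
    """F1 = 2TP / (2TP + FP + FN), equivalent to 2PR / (P + R)."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def fit_threshold(X, y):
    """Hypothetical toy model: cut halfway between the two class means."""
    pos = [x for x, t in zip(X, y) if t == 1]
    neg = [x for x, t in zip(X, y) if t == 0]
    if not pos or not neg:  # a resample may contain only one class
        label = 1 if pos else 0
        return lambda x: label
    cut = (mean(pos) + mean(neg)) / 2
    return lambda x: int(x > cut)

def bootstrap_f1(fit, X_train, y_train, X_test, y_test, n_bootstrap=5, seed=1):
    """Refit on resampled training sets, score each fit on the fixed test set."""
    rng = random.Random(seed)
    n = len(X_train)
    scores = []
    for _ in range(n_bootstrap):
        idx = [rng.randrange(n) for _ in range(n)]  # sample n rows with replacement
        predict = fit([X_train[i] for i in idx], [y_train[i] for i in idx])
        scores.append(f1_score(y_test, [predict(x) for x in X_test]))
    return mean(scores), pstdev(scores)  # population std, like NumPy's default

# Toy 1-D data (hypothetical): label is mostly 1 for larger feature values
X_train = list(range(20))
y_train = [int(x >= 12) for x in X_train]
y_train[10], y_train[14] = 1, 0  # a little label noise
X_test = [2, 7, 10, 11, 13, 16, 18]
y_test = [0, 0, 1, 0, 1, 1, 1]

avg, std = bootstrap_f1(fit_threshold, X_train, y_train, X_test, y_test)
print(f"f1: {avg:.4f} ± {std:.4f}")
```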
Analyze the results
In [7]:
# Let's have a look at the final results
atom.results
Out[7]:
| | metric_train | metric_test | time_fit | mean_bootstrap | std_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| ET | 1.0 | 0.582278 | 0.194s | 0.559296 | 0.015532 | 0.794s | 0.991s |
| RF | 1.0 | 0.598540 | 0.244s | 0.591987 | 0.023073 | 1.012s | 1.257s |
In [8]:
# Visualize the bootstrap results
atom.plot_results(title="RF vs ET performance")
In [9]:
# Print the results of some common metrics
atom.evaluate()
Out[9]:
| | accuracy | average_precision | balanced_accuracy | f1 | jaccard | matthews_corrcoef | precision | recall | roc_auc |
|---|---|---|---|---|---|---|---|---|---|
| ET | 0.853982 | 0.691655 | 0.714920 | 0.582278 | 0.410714 | 0.522039 | 0.766667 | 0.469388 | 0.863625 |
| RF | 0.853982 | 0.694780 | 0.726727 | 0.598540 | 0.427083 | 0.527831 | 0.740964 | 0.502041 | 0.866265 |
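Several columns in this table are tied together by closed-form identities, so they can be checked by hand. With precision $P$ and recall $R$, the f1 and jaccard values of the ET row follow directly:

```latex
F_1 = \frac{2PR}{P + R}
    = \frac{2 \times 0.766667 \times 0.469388}{0.766667 + 0.469388}
    \approx 0.582278,
\qquad
J = \frac{F_1}{2 - F_1}
  = \frac{0.582278}{1.417722}
  \approx 0.410714
```

The same identities reproduce 0.598540 and 0.427083 for RF. Both models reach a precision of roughly 0.74 to 0.77 but a recall of only about 0.47 to 0.50, so roughly half of the rainy days in the test set are missed at the default decision threshold.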
In [10]:
# The winner attribute returns the best model (atom.winner == atom.rf)
print(f"The winner is the {atom.winner.fullname} model!!")
The winner is the Random Forest model!!
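Here RF wins on both the test score (0.598540 vs 0.582278) and the mean bootstrap score (0.591987 vs 0.559296), so either rule picks the same model. The sketch below assumes the mean bootstrap score is preferred when it exists, with the test score as fallback; that rule is our assumption, not a quote of ATOM's selection logic, and `pick_winner` is a hypothetical helper working on values copied from atom.results.

```python
# Values copied from atom.results above
results = {
    "ET": {"metric_test": 0.582278, "mean_bootstrap": 0.559296},
    "RF": {"metric_test": 0.598540, "mean_bootstrap": 0.591987},
}

def pick_winner(results):
    """Hypothetical rule: highest mean bootstrap score, else highest test score."""
    def score(row):
        return row.get("mean_bootstrap", row["metric_test"])
    return max(results, key=lambda name: score(results[name]))

print(pick_winner(results))  # RF
```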
In [11]:
# Visualize the distribution of predicted probabilities
atom.winner.plot_probabilities()
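A probability-distribution plot summarizes the predicted probability of the positive class (RainTomorrow = 1), split by the true class: a well-separating model pushes class 0 toward low probabilities and class 1 toward high ones. The text-only sketch below bins toy probabilities the same way. The values are made up for illustration and `histogram_by_class` is a hypothetical helper, not part of ATOM.

```python
def histogram_by_class(y_true, proba, n_bins=5):
    """Count predicted probabilities per bin, separately for each true class."""
    hist = {0: [0] * n_bins, 1: [0] * n_bins}
    for t, p in zip(y_true, proba):
        b = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes into the last bin
        hist[t][b] += 1
    return hist

# Hypothetical toy values: true labels and predicted P(RainTomorrow = 1)
y_true = [0, 0, 0, 1, 1, 1]
proba = [0.05, 0.15, 0.45, 0.55, 0.85, 1.0]

for label, counts in histogram_by_class(y_true, proba).items():
    bars = " ".join("#" * c or "." for c in counts)  # "." marks an empty bin
    print(f"class {label}: {bars}")
```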
In [12]:
# Compare how different metrics perform for different thresholds
atom.winner.plot_threshold(metric=["f1", "accuracy", "average_precision"], steps=50)
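A threshold plot re-scores the model after turning probabilities into labels at many cut-offs instead of the default one. The sketch below does this for f1 and accuracy at `steps` thresholds. It assumes evenly spaced thresholds from 0 to 1 inclusive and a `>=` comparison, which may differ from ATOM's internals. average_precision is left out because, as usually defined, it is computed from the ranked probabilities rather than from labels at a single threshold. The data and helper names are hypothetical.

```python
def f1_and_accuracy(y_true, y_pred):
    """Return (f1, accuracy) for hard 0/1 predictions."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return f1, acc

def threshold_sweep(y_true, proba, steps=50):
    """Score hard predictions at `steps` evenly spaced thresholds in [0, 1]."""
    rows = []
    for i in range(steps):
        thr = i / (steps - 1)
        y_pred = [int(p >= thr) for p in proba]  # predict rain when p >= threshold
        rows.append((thr, *f1_and_accuracy(y_true, y_pred)))
    return rows

# Hypothetical toy labels and predicted probabilities
y_true = [0, 0, 1, 1]
proba = [0.1, 0.6, 0.4, 0.9]

best_thr, best_f1, best_acc = max(threshold_sweep(y_true, proba), key=lambda r: r[1])
print(f"best threshold {best_thr:.3f}: f1={best_f1:.2f}, accuracy={best_acc:.2f}")
```

Because the classes are imbalanced, the threshold that maximizes f1 is generally not the one that maximizes accuracy, which is the trade-off the plot above is meant to expose.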