Example: Imbalanced datasets
This example shows how ATOM can help you handle imbalanced datasets. We will evaluate the performance of three Random Forest models: one trained directly on the imbalanced dataset, one trained on an oversampled dataset, and one trained on an undersampled dataset.
Load the data
In [1]:
# Import packages
from atom import ATOMClassifier
from sklearn.datasets import make_classification
UserWarning: The pandas version installed (1.5.3) does not match the supported pandas version in Modin (1.5.2). This may cause undesired side effects!
In [2]:
# Create a mock imbalanced dataset (~95% of the samples in class 0)
X, y = make_classification(
    n_samples=5000,
    n_features=30,
    n_informative=20,
    weights=(0.95,),
    random_state=1,
)
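With `weights=(0.95,)`, `make_classification` assigns roughly 95% of the samples to class 0 and gives the remaining weight to class 1. Its default `flip_y=0.01` reassigns a small fraction of labels at random, so the realized counts can drift slightly from 4750/250; the class tables later in this example show 4731/269. The stdlib-only sketch below checks that arithmetic; `class_proportions` is a hypothetical helper, not part of ATOM or scikit-learn.

```python
from collections import Counter


def class_proportions(labels):
    """Return {class: (count, fraction)} for a sequence of labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: (n, n / total) for cls, n in sorted(counts.items())}


# Expected split for weights=(0.95,) on 5000 samples, before any label noise
n_samples = 5000
expected_majority = round(0.95 * n_samples)  # 4750
expected_minority = n_samples - expected_majority  # 250

# Class counts reported later in this example (atom.classes, "dataset" column)
observed = [0] * 4731 + [1] * 269
for cls, (n, frac) in class_proportions(observed).items():
    print(f"class {cls}: {n} samples ({frac:.1%})")
# class 0: 4731 samples (94.6%)
# class 1: 269 samples (5.4%)
```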
Run the pipeline
In [3]:
# Initialize atom
atom = ATOMClassifier(X, y, test_size=0.2, verbose=2, random_state=1)
<< ================== ATOM ================== >>
Algorithm task: binary classification.

Dataset stats ==================== >>
Shape: (5000, 31)
Train set size: 4000
Test set size: 1000
-------------------------------------
Memory: 1.24 MB
Scaled: False
Outlier values: 570 (0.5%)
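`test_size=0.2` puts 1000 of the 5000 rows in the test set. The per-class counts reported by `atom.classes` further down (946/54 in test, 3785/215 in train) are consistent with a stratified split, in which each class is divided in the same ratio. A minimal pure-Python sketch of that idea, assuming per-class rounding with `round()`; `stratified_split` is a hypothetical helper, and ATOM or scikit-learn may allocate leftover rows differently.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.2, seed=1):
    """Split row indices so that each class keeps (roughly) the same ratio.

    Illustrative sketch: per-class rounding with round(); a real splitter
    may distribute leftover rows differently.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, test = [], []
    for label, indices in sorted(by_class.items()):
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)  # 946 for class 0, 54 for class 1
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return sorted(train), sorted(test)


labels = [0] * 4731 + [1] * 269
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))  # 4000 1000
print(sum(labels[i] for i in test_idx))  # 54 minority rows in the test set
```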
In [4]:
# Let's have a look at the data. Note that, since the input wasn't
# a dataframe, atom has given default names to the columns.
atom.head()
Out[4]:
| | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | ... | x21 | x22 | x23 | x24 | x25 | x26 | x27 | x28 | x29 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.535760 | -2.426045 | 1.256836 | 0.374501 | -3.241958 | -1.239468 | -0.208750 | -6.015995 | 3.698669 | 0.112512 | ... | 0.044302 | -1.935727 | 10.870353 | 0.286755 | -2.416507 | 0.556990 | -1.522635 | 3.719201 | 1.449135 | 0 |
| 1 | -3.311935 | -3.149920 | -0.801252 | -2.644414 | -0.704889 | -3.312256 | 0.714515 | 2.992345 | 5.056910 | 3.036775 | ... | 2.224359 | 0.451273 | -1.822108 | -1.435801 | 0.036132 | -1.364583 | 1.215663 | 5.232161 | 1.408798 | 0 |
| 2 | 3.821199 | 1.328129 | -1.000720 | -13.151697 | 0.254253 | 1.263636 | -1.088451 | 4.924264 | -1.225646 | -6.974824 | ... | 3.541222 | 1.686667 | -13.763703 | -1.321256 | 1.677687 | 0.774966 | -5.067689 | 4.663386 | -1.714186 | 0 |
| 3 | 5.931126 | 3.338830 | 0.545906 | 2.296355 | -3.941088 | 3.527252 | -0.158770 | 3.138381 | -0.927460 | -1.642079 | ... | -3.634442 | 7.853176 | -8.457598 | 0.000490 | -2.612756 | -1.138206 | 0.497150 | 4.351289 | -0.321748 | 0 |
| 4 | -2.829472 | -1.227185 | -0.751892 | 3.056106 | -1.988920 | -2.219184 | -0.075882 | 5.790102 | -2.786671 | 2.023458 | ... | 4.057954 | 1.178564 | -15.028187 | 1.627140 | -1.093587 | -0.422655 | 1.777011 | 6.660638 | -2.553723 | 0 |
5 rows × 31 columns
In [5]:
# Let's start by reducing the number of features: keep the 12 best
# according to recursive feature elimination with a random forest
atom.feature_selection("RFE", solver="RF", n_features=12)
Fitting FeatureSelector...
Performing feature selection ...
 --> rfe selected 12 features from the dataset.
   --> Dropping feature x1 (rank 8).
   --> Dropping feature x2 (rank 11).
   --> Dropping feature x4 (rank 3).
   --> Dropping feature x6 (rank 16).
   --> Dropping feature x7 (rank 14).
   --> Dropping feature x10 (rank 19).
   --> Dropping feature x12 (rank 13).
   --> Dropping feature x13 (rank 12).
   --> Dropping feature x14 (rank 9).
   --> Dropping feature x16 (rank 10).
   --> Dropping feature x18 (rank 17).
   --> Dropping feature x19 (rank 2).
   --> Dropping feature x20 (rank 4).
   --> Dropping feature x22 (rank 7).
   --> Dropping feature x23 (rank 5).
   --> Dropping feature x24 (rank 18).
   --> Dropping feature x25 (rank 6).
   --> Dropping feature x26 (rank 15).
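RFE fits the estimator (here a random forest), drops the feature with the lowest importance, and refits on the rest until `n_features` remain. Following scikit-learn's `ranking_` convention, kept features get rank 1 and the last feature eliminated gets rank 2, so in the log above x19 (rank 2) was the final drop and x10 (rank 19) the first. The sketch below reproduces only that ranking bookkeeping; `importance_fn` is a hypothetical stand-in for refitting the model and reading its feature importances.

```python
def recursive_feature_elimination(features, importance_fn, n_features):
    """Return {feature: rank} following the RFE ranking convention.

    `importance_fn(remaining)` is a hypothetical stand-in for refitting the
    estimator on `remaining` and reading its feature importances; it must
    return {feature: importance}. One feature is dropped per round (step=1).
    """
    remaining = list(features)
    eliminated = []  # in order of elimination
    while len(remaining) > n_features:
        scores = importance_fn(remaining)
        weakest = min(remaining, key=lambda f: scores[f])
        remaining.remove(weakest)
        eliminated.append(weakest)

    ranks = {f: 1 for f in remaining}  # selected features share rank 1
    n_rounds = len(eliminated)
    for i, f in enumerate(eliminated):
        # First feature dropped gets n_rounds + 1, last feature dropped gets 2
        ranks[f] = n_rounds + 1 - i
    return ranks


# Toy example with fixed importances (a real RFE refits the model every round)
toy_importance = {"a": 0.40, "b": 0.05, "c": 0.30, "d": 0.15, "e": 0.10}
ranks = recursive_feature_elimination(
    toy_importance, lambda rem: {f: toy_importance[f] for f in rem}, n_features=2
)
print(ranks)  # {'a': 1, 'c': 1, 'b': 4, 'e': 3, 'd': 2}
```

With 30 features and `n_features=12` there are 18 rounds, which is why the dropped features in the log carry ranks 2 through 19.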
In [6]:
# Fit a model directly on the imbalanced data
atom.run("RF", metric="ba")
Training ========================= >>
Models: RF
Metric: balanced_accuracy

Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> balanced_accuracy: 1.0
Test evaluation --> balanced_accuracy: 0.5556
Time elapsed: 1.266s
-------------------------------------------------
Total time: 1.266s

Final results ==================== >>
Total time: 1.268s
-------------------------------------
RandomForest --> balanced_accuracy: 0.5556 ~
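The gap between train (1.0) and test (0.5556) balanced accuracy shows that the forest fits the training set perfectly but finds few minority samples in the test set. Balanced accuracy is the mean of the per-class recalls, so always predicting the majority class scores 0.5 no matter how skewed the data are, while plain accuracy would still look high. In the sketch below, the predictions are back-calculated from the RF row of the `atom.evaluate()` table at the end of this example (all 946 negatives correct, 6 of 54 positives found); they are an inference, not output that ATOM prints.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall (the definition behind scikit-learn's balanced_accuracy_score)."""
    classes = sorted(set(y_true))
    recalls = []
    for cls in classes:
        idx = [i for i, t in enumerate(y_true) if t == cls]
        hits = sum(1 for i in idx if y_pred[i] == cls)
        recalls.append(hits / len(idx))
    return sum(recalls) / len(recalls)


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Test-set composition from this example: 946 negatives, 54 positives.
# Predictions back-calculated from the RF row of atom.evaluate():
# every negative correct, 6 of the 54 positives found.
y_true = [0] * 946 + [1] * 54
y_pred = [0] * 946 + [1] * 6 + [0] * 48
print(round(accuracy(y_true, y_pred), 4))           # 0.952
print(round(balanced_accuracy(y_true, y_pred), 4))  # 0.5556

# A model that always predicts the majority class
always_zero = [0] * 1000
print(round(accuracy(y_true, always_zero), 4))           # 0.946
print(round(balanced_accuracy(y_true, always_zero), 4))  # 0.5
```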
In [7]:
# The transformer and the model have been added to the branch
atom.branch
Out[7]:
Branch(master)
Oversampling
In [8]:
# Create a new branch for oversampling
atom.branch = "oversample"
New branch oversample successfully created.
In [9]:
# Perform oversampling of the minority class
atom.balance(strategy="smote")
Oversampling with SMOTE...
 --> Adding 3570 samples to class 1.
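SMOTE builds each synthetic minority sample by choosing a minority point, picking one of its k nearest minority neighbours, and placing a new point at a random position on the segment between the two. Here 3570 = 3785 - 215 new samples bring class 1 up to the size of class 0 in the training set, while the test set is left untouched. A simplified pure-Python sketch, assuming `k=5` (imbalanced-learn's default `k_neighbors`); `smote` is a hypothetical helper that omits the library's optimizations and sampling-strategy options.

```python
import math
import random


def smote(minority, n_new, k=5, seed=1):
    """Generate n_new synthetic points from a list of minority-class vectors.

    Simplified SMOTE: for a random minority point x, pick one of its k
    nearest minority neighbours nn and return x + u * (nn - x), u ~ U(0, 1).
    """
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        i = rng.randrange(len(minority))
        x = minority[i]
        # k nearest neighbours of x within the minority class, excluding x itself
        others = [p for j, p in enumerate(minority) if j != i]
        neighbours = sorted(others, key=lambda p: math.dist(x, p))[:k]
        nn = rng.choice(neighbours)
        u = rng.random()
        synthetic.append([xi + u * (ni - xi) for xi, ni in zip(x, nn)])
    return synthetic


# Counts from this example: the training set had 3785 majority and 215 minority rows
n_majority, n_minority = 3785, 215
print(n_majority - n_minority)  # 3570 synthetic samples needed to balance

toy_minority = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
new_points = smote(toy_minority, n_new=3, k=2)
print(len(new_points))  # 3
```

Because every new point lies between two existing minority points, SMOTE adds variety inside the minority region instead of duplicating rows.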
In [10]:
atom.classes # Check the balanced training set!
Out[10]:
| | dataset | train | test |
|---|---|---|---|
| 0 | 4731 | 3785 | 946 |
| 1 | 3839 | 3785 | 54 |
In [11]:
# Train another model on the new branch. Add a tag after
# the model's acronym to distinguish it from the first model
atom.run("rf_os") # os for oversample
Training ========================= >>
Models: RF_os
Metric: balanced_accuracy

Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> balanced_accuracy: 1.0
Test evaluation --> balanced_accuracy: 0.7672
Time elapsed: 2.286s
-------------------------------------------------
Total time: 2.286s

Final results ==================== >>
Total time: 2.288s
-------------------------------------
RandomForest --> balanced_accuracy: 0.7672 ~
Undersampling
In [12]:
# Create the undersampling branch. Split from master so that
# the new branch doesn't inherit the oversampling transformer
atom.branch = "undersample_from_master"
New branch undersample successfully created.
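The `"<name>_from_<parent>"` form creates the branch `undersample` from `master` instead of from the current `oversample` branch, so it inherits the RFE selector but not the SMOTE balancer. The toy class below is a hypothetical illustration of that copy-on-branch behaviour (`ToyBranches`, `switch` and `add` are invented names), not ATOM's implementation.

```python
class ToyBranches:
    """Hypothetical toy model of branching; NOT ATOM's implementation.

    Each branch stores its own list of pipeline steps. Creating a branch
    copies the steps of its parent: the current branch by default, or the
    branch named after "_from_".
    """

    def __init__(self):
        self.pipelines = {"master": []}
        self.current = "master"

    def switch(self, name):
        if name in self.pipelines:
            self.current = name
            return
        parent = self.current
        if "_from_" in name:
            name, parent = name.split("_from_", 1)
        self.pipelines[name] = list(self.pipelines[parent])  # copy, not alias
        self.current = name

    def add(self, step):
        self.pipelines[self.current].append(step)


b = ToyBranches()
b.add("FeatureSelector(RFE)")
b.switch("oversample")
b.add("Balancer(SMOTE)")
b.switch("undersample_from_master")
b.add("Balancer(NearMiss)")
print(b.pipelines["oversample"])   # ['FeatureSelector(RFE)', 'Balancer(SMOTE)']
print(b.pipelines["undersample"])  # ['FeatureSelector(RFE)', 'Balancer(NearMiss)']
```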
In [13]:
atom.classes # In this branch, the data is still imbalanced
Out[13]:
| | dataset | train | test |
|---|---|---|---|
| 0 | 4731 | 3785 | 946 |
| 1 | 269 | 215 | 54 |
In [14]:
# Perform undersampling of the majority class
atom.balance(strategy="NearMiss")
Undersampling with NearMiss...
 --> Removing 3570 samples from class 0.
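NearMiss balances the data by removing majority samples instead of synthesizing minority ones, keeping the majority samples that lie closest to the minority class. Assuming imbalanced-learn's defaults (`version=1`, `n_neighbors=3`), each majority sample is scored by its mean distance to its 3 nearest minority samples and the lowest-scoring ones are kept; here 3785 - 215 = 3570 are dropped. A pure-Python sketch of that rule, where `near_miss_1` is a hypothetical name:

```python
import math


def near_miss_1(majority, minority, n_keep, k=3):
    """Keep the n_keep majority points with the smallest mean distance
    to their k nearest minority points (NearMiss version 1)."""

    def score(x):
        dists = sorted(math.dist(x, m) for m in minority)[:k]
        return sum(dists) / len(dists)

    return sorted(majority, key=score)[:n_keep]


# Counts from this example: 3785 majority rows reduced to match 215 minority rows
print(3785 - 215)  # 3570 majority samples removed

minority = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
majority = [[0.5, 0.5], [5.0, 5.0], [1.0, 1.0], [10.0, 0.0]]
print(near_miss_1(majority, minority, n_keep=2))  # [[0.5, 0.5], [1.0, 1.0]]
```

Keeping only the majority samples near the class boundary discards most of the majority class, which helps explain the trade-off visible in the RF_us results.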
In [15]:
atom.run("rf_us")  # us for undersample
Training ========================= >>
Models: RF_us
Metric: balanced_accuracy

Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> balanced_accuracy: 1.0
Test evaluation --> balanced_accuracy: 0.6706
Time elapsed: 0.211s
-------------------------------------------------
Total time: 0.211s

Final results ==================== >>
Total time: 0.212s
-------------------------------------
RandomForest --> balanced_accuracy: 0.6706 ~
In [16]:
# Check that the branch only contains the desired transformers
atom.branch
Out[16]:
Branch(undersample)
In [17]:
# Visualize the complete pipeline
atom.plot_pipeline()
Analyze the results
In [18]:
atom.evaluate()
Out[18]:
| | accuracy | average_precision | balanced_accuracy | f1 | jaccard | matthews_corrcoef | precision | recall | roc_auc |
|---|---|---|---|---|---|---|---|---|---|
| RF | 0.952 | 0.6562 | 0.5556 | 0.2000 | 0.1111 | 0.3252 | 1.000 | 0.1111 | 0.9107 |
| RF_os | 0.956 | 0.6215 | 0.7672 | 0.5769 | 0.4054 | 0.5542 | 0.600 | 0.5556 | 0.9251 |
| RF_us | 0.509 | 0.3687 | 0.6706 | 0.1578 | 0.0857 | 0.1545 | 0.087 | 0.8519 | 0.8258 |
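All threshold-based columns follow from one confusion matrix per model. With 946 negatives and 54 positives in the test set, the reported precision, recall and accuracy pin down the counts below; they are back-calculated from the table, not printed by ATOM. `average_precision` and `roc_auc` depend on the predicted probabilities, so they cannot be recovered this way. A stdlib sketch that recomputes the other columns:

```python
import math


def metrics(tp, fp, tn, fn):
    """Threshold-based metrics with the minority class (1) as the positive class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    specificity = tn / (tn + fp)
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
        "balanced_accuracy": (recall + specificity) / 2,
        "f1": 2 * tp / (2 * tp + fp + fn),
        "jaccard": tp / (tp + fp + fn),
        "matthews_corrcoef": (tp * tn - fp * fn) / mcc_den,
        "precision": precision,
        "recall": recall,
    }


# Confusion counts back-calculated from the evaluate() table (test set: 946 neg, 54 pos)
confusion = {
    "RF": dict(tp=6, fp=0, tn=946, fn=48),
    "RF_os": dict(tp=30, fp=20, tn=926, fn=24),
    "RF_us": dict(tp=46, fp=483, tn=463, fn=8),
}
for name, counts in confusion.items():
    row = {k: round(v, 4) for k, v in metrics(**counts).items()}
    print(name, row)
```

Rounded to four decimals, each printed row matches the table. The RF_us counts make the trade-off explicit: NearMiss raises recall to 46/54 at the cost of 483 false positives, which drags accuracy down to 0.509.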
In [19]:
atom.plot_prc()
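`plot_prc` draws precision-recall curves, and the `average_precision` column summarizes each one. scikit-learn defines AP as the sum over thresholds of (R_n - R_{n-1}) * P_n. Unlike ROC AUC, whose chance level is 0.5, a random ranking has an expected AP equal to the positive rate (54/1000 = 0.054 on this test set), which is one reason PR curves are often preferred for imbalanced data. A pure-Python sketch of the step-wise definition, treating tied scores as a single threshold; `average_precision` here is a hypothetical helper, not ATOM's code.

```python
def average_precision(y_true, scores):
    """AP = sum over thresholds of (R_n - R_{n-1}) * P_n, the step-wise
    definition used by scikit-learn's average_precision_score."""
    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    n_pos = sum(y_true)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    i = 0
    while i < len(pairs):
        # All samples sharing the same score form one threshold
        threshold = pairs[i][0]
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


y_true = [1, 0, 1, 0, 0]
scores = [0.9, 0.8, 0.7, 0.3, 0.1]
print(round(average_precision(y_true, scores), 4))  # 0.8333
```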
In [20]:
atom.plot_roc()
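ROC AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one (ties count half). Because it ignores the decision threshold, RF can reach a roc_auc of 0.9107 while its balanced accuracy is only 0.5556, which suggests the ranking is good but the threshold behind its class predictions flags few samples as positive. A pairwise pure-Python sketch (`roc_auc` is a hypothetical helper), quadratic in the number of samples and meant only for small examples:

```python
def roc_auc(y_true, scores):
    """ROC AUC as the probability that a random positive outranks a random
    negative (ties count as 0.5). O(n_pos * n_neg), fine for small examples."""
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


y_true = [1, 0, 1, 0, 0]
scores = [0.9, 0.8, 0.4, 0.5, 0.1]
# The positive scored 0.9 beats all 3 negatives; the one scored 0.4 beats only 0.1
print(round(roc_auc(y_true, scores), 4))  # 0.6667
```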