plot_confusion_matrix
Plot a model's confusion matrix. For a single model, the plot shows a heatmap. For multiple models, it compares the number of true positives (TP), false positives (FP), false negatives (FN) and true negatives (TN) in a barplot (not implemented for multiclass classification tasks). Only available for classification tasks.
Parameters:
models: str, sequence or None, optional (default=None)
dataset: str, optional (default="test")
normalize: bool, optional (default=False)
title: str or None, optional (default=None)
figsize: tuple, optional (default=None)
filename: str or None, optional (default=None)
display: bool or None, optional (default=True)
Returns:
fig: matplotlib.figure.Figure
Plot object. Only returned if display=None.
Example
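from sklearn.datasets import load_breast_cancer
from atom import ATOMClassifier

# Example data (an assumption, not part of the original snippet):
# any binary classification dataset works here
X, y = load_breast_cancer(return_X_y=True)

atom = ATOMClassifier(X, y)
atom.run(["Tree", "Bag"])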
atom.Tree.plot_confusion_matrix(normalize=True) # For one model
[Figure: plot_confusion_matrix_1, confusion matrix heatmap for one model]
atom.plot_confusion_matrix() # For multiple models
[Figure: plot_confusion_matrix_2, barplot comparing TP, FP, FN and TN for multiple models]
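To make the quantities in the two plots concrete, here is a minimal pure-Python sketch that counts TP, FP, FN and TN for a binary task and builds the 2x2 matrix a heatmap would display. It is an illustration only, not ATOM's implementation: the helper names `binary_confusion_counts` and `normalize_rows` and the label lists are hypothetical, and the row-wise normalization (each true class sums to 1) is an assumed convention, since this page does not say how `normalize=True` scales the values.

```python
from collections import Counter


def binary_confusion_counts(y_true, y_pred, positive=1):
    """Count TP, FP, FN and TN for a binary classification task."""
    counts = Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == positive:
            # Actual positive: either correctly found (TP) or missed (FN)
            counts["TP" if pred == positive else "FN"] += 1
        else:
            # Actual negative: either wrongly flagged (FP) or correctly rejected (TN)
            counts["FP" if pred == positive else "TN"] += 1
    return {key: counts[key] for key in ("TP", "FP", "FN", "TN")}


def normalize_rows(matrix):
    """Divide each row by its total (assumed convention: rows are true classes)."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([value / total if total else 0.0 for value in row])
    return normalized


# Hypothetical labels, for illustration only
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]

counts = binary_confusion_counts(y_true, y_pred)

# Rows = true class (0, 1), columns = predicted class (0, 1)
matrix = [
    [counts["TN"], counts["FP"]],
    [counts["FN"], counts["TP"]],
]

print(counts)
print(normalize_rows(matrix))
```

The four counts correspond to the bars compared across models in the multi-model plot, while the 2x2 matrix corresponds to the cells of the single-model heatmap.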