plot_confusion_matrix
Plot a model's confusion matrix. For a single model, the plot is a heatmap. For multiple models, a barplot compares their true positives (TP), false positives (FP), false negatives (FN) and true negatives (TN); this comparison is not implemented for multiclass classification tasks. Only available for classification tasks.
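As a minimal, framework-free illustration of the four counts that the multi-model barplot compares, the sketch below tallies them for a binary task in plain Python. The helper `binary_confusion_counts` and its `positive` argument are hypothetical and not part of ATOM; the sketch only shows what TP, FP, FN and TN mean.

```python
def binary_confusion_counts(y_true, y_pred, positive=1):
    """Count TP, FP, FN and TN for a binary task (hypothetical helper, not ATOM's API)."""
    tp = fp = fn = tn = 0
    for t, p in zip(y_true, y_pred):
        if p == positive:
            if t == positive:
                tp += 1  # predicted positive, actually positive
            else:
                fp += 1  # predicted positive, actually negative
        else:
            if t == positive:
                fn += 1  # predicted negative, actually positive
            else:
                tn += 1  # predicted negative, actually negative
    return {"TP": tp, "FP": fp, "FN": fn, "TN": tn}


y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 0]
print(binary_confusion_counts(y_true, y_pred))
# {'TP': 3, 'FP': 1, 'FN': 1, 'TN': 3}
```

Every sample lands in exactly one of the four cells, so the counts always sum to the number of samples.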
Parameters:
models: str, sequence or None, optional (default=None)
dataset: str, optional (default="test")
normalize: bool, optional (default=False)
title: str or None, optional (default=None)
figsize: tuple, optional (default=None)
filename: str or None, optional (default=None)
display: bool or None, optional (default=True)

Returns:
fig: matplotlib.figure.Figure
Plot object. Only returned if display=None.
Example
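from sklearn.datasets import load_breast_cancer
from atom import ATOMClassifier

# Binary dataset: the multi-model barplot is not implemented for multiclass tasks
X, y = load_breast_cancer(return_X_y=True)

atom = ATOMClassifier(X, y)
atom.run(["Tree", "Bag"])
atom.Tree.plot_confusion_matrix(normalize=True)  # For one model
atom.plot_confusion_matrix()  # For multiple models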
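The example passes `normalize=True` for the single-model heatmap. The sketch below shows one common way to normalize a confusion matrix: dividing each row (one true class) by its total so that every row sums to 1. Whether ATOM normalizes over rows in exactly this way is an assumption, and `build_confusion_matrix` and `normalize_rows` are hypothetical plain-Python helpers, not ATOM functions.

```python
def build_confusion_matrix(y_true, y_pred, labels):
    """Square matrix of counts: rows are true labels, columns are predicted labels."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def normalize_rows(matrix):
    """Divide each row by its sum (assumed row-wise convention); empty rows become zeros."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized


y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1]
cm = build_confusion_matrix(y_true, y_pred, labels=[0, 1])
print(cm)                  # [[3, 1], [0, 2]]
print(normalize_rows(cm))  # [[0.75, 0.25], [0.0, 1.0]]
```

Row-wise normalization turns each cell into the fraction of a true class that was assigned to each predicted class, which makes classes of different sizes directly comparable on the heatmap.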