
plot_confusion_matrix


method plot_confusion_matrix(models=None, dataset="test", normalize=False, title=None, figsize=None, filename=None, display=True)

Plot a model's confusion matrix. Only available for classification tasks. For a single model, the plot shows a heatmap. For multiple models, it compares the true positives (TP), false positives (FP), false negatives (FN), and true negatives (TN) in a barplot; this comparison is not implemented for multiclass classification tasks.
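For a binary task, the four counts compared in the multi-model barplot are the cells of the 2×2 confusion matrix. The minimal sketch below uses scikit-learn's `confusion_matrix` (not ATOM's internal code) and assumes labels 0 (negative) and 1 (positive); the label lists are made up for illustration.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical true labels and predictions for a binary task
y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 1]

# scikit-learn sorts the labels as [0, 1] for both rows (true) and
# columns (predicted), so ravel() yields the cells as TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(f"TP={tp}, FP={fp}, FN={fn}, TN={tn}")
```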

Parameters:

models: str, sequence or None, optional (default=None)
Name of the models to plot. If None, all models in the pipeline are selected.

dataset: str, optional (default="test")
Data set on which to calculate the confusion matrix. Choose from: "train", "test" or "holdout".

normalize: bool, optional (default=False)
Whether to normalize the matrix, showing fractions instead of absolute counts.
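The normalization axis is not specified here. A common convention, used by scikit-learn's `normalize="true"`, divides each row by the number of samples in that true class, so every row sums to 1. A minimal sketch under that assumption, with made-up labels:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical true labels and predictions
y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 1]

cm = confusion_matrix(y_true, y_pred)  # raw counts: [[2, 1], [1, 4]]

# Assumption: normalize over the true labels (rows)
cm_norm = cm / cm.sum(axis=1, keepdims=True)

# scikit-learn's built-in option gives the same row-normalized matrix
same_as_sklearn = np.allclose(cm_norm, confusion_matrix(y_true, y_pred, normalize="true"))
print(cm_norm.round(2), same_as_sklearn)
```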

title: str or None, optional (default=None)
Title of the plot. If None, the title is left empty.

figsize: tuple, optional (default=None)
Figure size, in the format (x, y). If None, the size adapts to the plot type.

filename: str or None, optional (default=None)
Name of the file. Use "auto" for automatic naming. If None, the figure is not saved.

display: bool or None, optional (default=True)
Whether to render the plot. If None, it returns the matplotlib figure.

Returns: matplotlib.figure.Figure
Plot object. Only returned if display=None.


Example

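from atom import ATOMClassifier
from sklearn.datasets import load_breast_cancer

# Load an example binary classification dataset
X, y = load_breast_cancer(return_X_y=True)

atom = ATOMClassifier(X, y)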
atom.run(["Tree", "Bag"])
atom.Tree.plot_confusion_matrix(normalize=True)  # For one model
atom.plot_confusion_matrix()  # For multiple models
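The single-model plot is a heatmap of the confusion matrix. The sketch below draws that kind of figure with matplotlib and scikit-learn only; it is not ATOM's plotting code, and the function name `confusion_heatmap`, the colormap, and the axis labels are hypothetical choices. Like `display=None`, it returns the figure instead of showing it, and it assumes row-wise normalization when `normalize=True`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs headless
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def confusion_heatmap(y_true, y_pred, normalize=False, title=None):
    """Hypothetical helper: draw an annotated confusion-matrix heatmap and return the figure."""
    # Assumption: normalize over the true labels (rows)
    cm = confusion_matrix(y_true, y_pred, normalize="true" if normalize else None)
    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(cm, cmap="Blues")
    fig.colorbar(im, ax=ax)

    # Annotate each cell with its count, or its fraction when normalized
    fmt = ".2f" if normalize else "d"
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center")

    ax.set_xticks(range(cm.shape[1]))
    ax.set_yticks(range(cm.shape[0]))
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    if title:
        ax.set_title(title)
    return fig


fig = confusion_heatmap([0, 1, 1, 0, 1, 0, 1, 1], [0, 1, 0, 0, 1, 1, 1, 1], normalize=True)
```

The returned figure can then be saved with `fig.savefig(...)` or shown, which is the same trade-off the `display` and `filename` parameters control.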