
plot_confusion_matrix


method plot_confusion_matrix(models=None, rows="test", target=0, threshold=0.5, title=None, legend="upper right", figsize=None, filename=None, display=True)

Plot a model's confusion matrix.

For a single model, the plot shows a heatmap. For multiple models, it compares the true positives (TP), false positives (FP), false negatives (FN) and true negatives (TN) in a bar plot (not implemented for multiclass classification tasks). This plot is only available for classification tasks.
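
The four quantities can be illustrated with a small, library-independent sketch. It counts TP, FP, FN and TN for a binary task, assuming label 1 is the positive class, and arranges them in a 2x2 matrix with true labels as rows and predicted labels as columns (the layout scikit-learn's confusion_matrix uses; the orientation of ATOM's own heatmap may differ). The helpers confusion_counts and as_matrix and the two sets of model predictions are hypothetical and not part of ATOM.

```python
from collections import Counter


def confusion_counts(y_true, y_pred, positive=1):
    """Return the TP, FP, FN and TN counts for a binary task."""
    counts = Counter()
    for truth, pred in zip(y_true, y_pred):
        if pred == positive:
            counts["TP" if truth == positive else "FP"] += 1
        else:
            counts["FN" if truth == positive else "TN"] += 1
    # Fixed key order; Counter returns 0 for outcomes that never occurred.
    return {key: counts[key] for key in ("TP", "FP", "FN", "TN")}


def as_matrix(c):
    """2x2 layout: rows are true labels (0, 1), columns are predicted labels (0, 1)."""
    return [[c["TN"], c["FP"]],
            [c["FN"], c["TP"]]]


y_true = [1, 0, 1, 1, 0, 0, 1, 0]
predictions = {  # made-up predictions from two hypothetical models
    "LR": [1, 0, 0, 1, 0, 1, 1, 0],
    "RF": [1, 1, 1, 1, 1, 0, 0, 0],
}

# One 2x2 matrix per model corresponds to the single-model heatmap;
# comparing the four counts across models mirrors the multi-model bar plot.
for name, y_pred in predictions.items():
    c = confusion_counts(y_true, y_pred)
    print(name, c, as_matrix(c))
```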

Tip

Set the threshold parameter to the value returned by the model's get_best_threshold method to optimize the results.
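
To illustrate what tuning the threshold can achieve, the sketch below picks, from a grid of candidate values, the threshold that maximizes the F1 score on a set of predicted positive-class probabilities. The choice of F1 as the metric, the 0.1 to 0.9 grid, the made-up data and the helper names f1_at_threshold and best_threshold are all assumptions for illustration; this is not how get_best_threshold is implemented.

```python
def f1_at_threshold(y_true, y_proba, threshold):
    """F1 score of the positive class (1) after thresholding the probabilities."""
    y_pred = [int(p >= threshold) for p in y_proba]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_threshold(y_true, y_proba, candidates):
    """Return the candidate with the highest F1 score (the first one on ties)."""
    return max(candidates, key=lambda th: f1_at_threshold(y_true, y_proba, th))


y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_proba = [0.10, 0.40, 0.35, 0.80, 0.65, 0.55, 0.45, 0.20]
candidates = [i / 10 for i in range(1, 10)]  # 0.1, 0.2, ..., 0.9

threshold = best_threshold(y_true, y_proba, candidates)
print(threshold, f1_at_threshold(y_true, y_proba, threshold))
```

With these made-up probabilities, a threshold of 0.3 gives an F1 score of 0.8, compared with about 0.57 at the default of 0.5.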

Parameters

models: int, str, Model, segment, sequence or None, default=None
Models to plot. If None, all models are selected.

rows: hashable, segment or sequence, default="test"
Selection of rows on which to calculate the confusion matrix.

target: int or str, default=0
Target column to look at. Only for multioutput tasks.

threshold: float, default=0.5
Threshold between 0 and 1 to convert predicted probabilities to class labels. Only for binary classification tasks.
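
A minimal sketch of the conversion this parameter describes, assuming the probabilities are those of the positive class and that a probability equal to the threshold maps to the positive class (ATOM's exact tie-breaking rule is an assumption here). apply_threshold is a hypothetical helper, not part of ATOM.

```python
def apply_threshold(y_proba, threshold=0.5):
    """Map positive-class probabilities to 0/1 labels (p >= threshold -> 1)."""
    if not 0 <= threshold <= 1:
        raise ValueError("threshold must lie between 0 and 1")
    return [1 if p >= threshold else 0 for p in y_proba]


y_proba = [0.15, 0.42, 0.58, 0.91]
print(apply_threshold(y_proba))       # [0, 0, 1, 1]
print(apply_threshold(y_proba, 0.4))  # [0, 1, 1, 1]
```

Lowering the threshold labels more samples as positive, so in the confusion matrix counts can only move from FN to TP and from TN to FP.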

title: str, dict or None, default=None
Title for the plot.

legend: str, dict or None, default="upper right"
Legend for the plot. See the user guide for an extended description of the choices.

  • If None: No legend is shown.
  • If str: Position to display the legend.
  • If dict: Legend configuration.

figsize: tuple or None, default=None
Figure's size in pixels, formatted as (x, y). If None, it adapts the size to the plot's type.

filename: str, Path or None, default=None
Save the plot using this name. Use "auto" for automatic naming. The file type depends on the extension of the provided name (.html, .png, .pdf, etc.). If the filename has no extension, the plot is saved as html. If None, the plot is not saved.
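
The extension rule above can be sketched with pathlib. resolve_filename is a hypothetical helper written for illustration, not ATOM's implementation; it models only the None case and the default-to-html rule, not "auto" naming.

```python
from pathlib import Path


def resolve_filename(filename):
    """Hypothetical helper: return the output path, defaulting to .html."""
    if filename is None:
        return None  # the plot is not saved
    path = Path(filename)
    if not path.suffix:
        # No extension given: fall back to an html file.
        path = path.with_suffix(".html")
    return path


print(resolve_filename("confusion_matrix"))      # confusion_matrix.html
print(resolve_filename("confusion_matrix.png"))  # confusion_matrix.png
```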

display: bool or None, default=True
Whether to render the plot. If None, it returns the figure.

Returns

go.Figure or None
Plot object. Only returned if display=None.


See Also

plot_calibration

Plot the calibration curve for a binary classifier.

plot_threshold

Plot metric performances against threshold values.


Example

>>> from atom import ATOMClassifier
>>> from sklearn.datasets import make_classification

>>> X, y = make_classification(n_samples=100, flip_y=0.2, random_state=1)

>>> atom = ATOMClassifier(X, y, test_size=0.4)
>>> atom.run(["LR", "RF"])
>>> atom.lr.plot_confusion_matrix()  # For one model

>>> atom.plot_confusion_matrix()  # For multiple models