plot_threshold
method plot_threshold(models=None, metric=None, dataset="test", target=0, steps=100, title=None, legend="lower left", figsize=(900, 600), filename=None, display=True)
Plot metric performance against the probability threshold used to assign the positive class.
This plot is available only for models with a predict_proba
method in a binary or multilabel classification task.
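The threshold sweep behind this plot can be sketched in plain Python. This is a minimal illustration, not ATOM's implementation. It assumes three things: the scores come from the positive-class column of predict_proba, the `steps` thresholds are evenly spaced on [0, 1], and a sample counts as positive when its probability is at least the threshold. `threshold_curve` and `accuracy` are hypothetical helpers.

```python
# Minimal sketch of a threshold sweep (not ATOM's code). Assumes positive-class
# probabilities, like predict_proba(X)[:, 1], and an evenly spaced grid on [0, 1].
# `threshold_curve` and `accuracy` are hypothetical helpers for illustration.
from typing import Callable, List, Sequence, Tuple


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of labels predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def threshold_curve(
    y_true: Sequence[int],
    proba: Sequence[float],
    metric: Callable[[Sequence[int], Sequence[int]], float],
    steps: int = 100,
) -> List[Tuple[float, float]]:
    """Return (threshold, score) pairs for `steps` thresholds in [0, 1]."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    curve = []
    for i in range(steps):
        # Evenly spaced thresholds from 0 to 1 inclusive (assumed grid).
        threshold = i / (steps - 1)
        # Predict positive when the probability reaches the threshold.
        y_pred = [int(p >= threshold) for p in proba]
        curve.append((threshold, metric(y_true, y_pred)))
    return curve


if __name__ == "__main__":
    y_true = [0, 0, 1, 1]
    proba = [0.1, 0.4, 0.35, 0.8]
    for threshold, score in threshold_curve(y_true, proba, accuracy, steps=5):
        print(f"threshold={threshold:.2f}  accuracy={score:.2f}")
```

Each (threshold, score) pair would correspond to one point on a line of the plot. Low thresholds label almost everything as positive and high thresholds label almost nothing as positive, so many metrics peak somewhere in between.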
Parameters
models: int, str, Model, slice, sequence or None, default=None
Models to plot. If None, all models are selected.
metric: str, func, scorer, sequence or None, default=None
Metric to plot. Choose from any of sklearn's scorers, a
function with signature metric(y_true, y_pred), a scorer
object or a sequence of these. Use a sequence or add +
between options to select more than one. If None, the
metric used to run the pipeline is selected.
dataset: str, default="test"
Data set on which to calculate the metric. Choose from:
"train", "test" or "holdout".
target: int or str, default=0
Target column to look at. Only for multilabel tasks.
steps: int, default=100
Number of thresholds measured.
title: str, dict or None, default=None
Title for the plot.
legend: str, dict or None, default="lower left"
Legend for the plot. See the user guide for
an extended description of the choices.
figsize: tuple, default=(900, 600)
Figure's size in pixels, format as (x, y).
filename: str or None, default=None
Save the plot using this name. Use "auto" for automatic
naming. The type of the file depends on the provided name
(.html, .png, .pdf, etc.). If filename has no file type,
the plot is saved as html. If None, the plot is not saved.
display: bool or None, default=True
Whether to render the plot. If None, it returns the figure.
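To illustrate the metric parameter, the function below has the documented metric(y_true, y_pred) signature. Specificity (true negative rate) is only an example choice, and `specificity` is a hypothetical name, not an ATOM or sklearn function.

```python
# Hypothetical custom metric with the signature metric(y_true, y_pred).
# Specificity (true negative rate) is used purely as an illustration.
from typing import Sequence


def specificity(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Share of actual negatives (label 0) that are predicted as negative."""
    negatives = [p for t, p in zip(y_true, y_pred) if t == 0]
    if not negatives:
        return 0.0  # no negatives to score; a convention chosen for this sketch
    return sum(p == 0 for p in negatives) / len(negatives)


if __name__ == "__main__":
    print(specificity([0, 0, 1, 1], [0, 1, 1, 1]))
```

Assuming the atom instance from the Example section, such a function could then be passed as atom.plot_threshold(metric=specificity), or combined with a scorer name in a sequence, e.g. metric=["f1", specificity].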
Returns
go.Figure or None
Plot object. Only returned if display=None.
See Also
Plot the calibration curve for a binary classifier.
Plot a model's confusion matrix.
Plot the probability distribution of the target classes.
Example
>>> from atom import ATOMClassifier
>>> import pandas as pd
>>> X = pd.read_csv("./examples/datasets/weatherAUS.csv")
>>> atom = ATOMClassifier(X, y="RainTomorrow", n_rows=1e4)
>>> atom.impute()
>>> atom.encode()
>>> atom.run(["LR", "RF"])
>>> atom.plot_threshold()