plot_cv_splits
method plot_cv_splits(models=None, title=None, legend="upper right", figsize=(900, 600), filename=None, display=True)
Visualize the cross-validation splits.
Plots the train and test splits for each cross-validation iteration of the cv object. The x-axis shows the sample index, where every point corresponds to the n-th row of the dataset. The top bar shows the original train/test split. Additionally, class labels and groups are plotted when relevant.
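Conceptually, the plot is a grid of samples by fold, where each cell marks whether a sample is used for training or testing in that iteration. The stdlib sketch below illustrates that layout for contiguous, unshuffled folds, using the same fold-sizing rule as scikit-learn's `KFold(shuffle=False)`. It is an illustration only, not ATOM's implementation; `kfold_test_indices` and `split_grid` are hypothetical helpers.

```python
def kfold_test_indices(n_samples: int, n_splits: int) -> list[list[int]]:
    """Test indices of each fold for unshuffled K-fold cross-validation.

    Every fold gets n_samples // n_splits samples and the first
    n_samples % n_splits folds get one extra sample.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    base, extra = divmod(n_samples, n_splits)
    folds, start = [], 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def split_grid(n_samples: int, n_splits: int) -> list[str]:
    """One text row per fold: 'T' marks a test sample, '.' a train sample."""
    rows = []
    for test in kfold_test_indices(n_samples, n_splits):
        test_set = set(test)
        rows.append("".join("T" if i in test_set else "." for i in range(n_samples)))
    return rows


for fold, row in enumerate(split_grid(10, 4)):
    print(f"fold {fold}: {row}")
# fold 0: TTT.......
# fold 1: ...TTT....
# fold 2: ......TT..
# fold 3: ........TT
```

Every sample lands in the test set of exactly one fold, which is why the test segments tile the x-axis from left to right when the data is not shuffled.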
Warning
This plot is only available for models that ran cross-validation using the cross_validate method.
Parameters
models: int, str, Model, segment, sequence or None, default=None
Model to plot. If None, all models are selected. Note that leaving the default option could raise an exception if there are multiple models. To avoid this, call the plot directly from a model, e.g., atom.lr.plot_cv_splits().
title: str, dict or None, default=None
Title for the plot.
legend: str, dict or None, default="upper right"
Legend for the plot. See the user guide for
an extended description of the choices.
figsize: tuple, default=(900, 600)
Figure's size in pixels, format as (x, y).
filename: str, Path or None, default=None
Save the plot using this name. Use "auto" for automatic naming. The type of the file depends on the provided name (.html, .png, .pdf, etc.). If filename has no file type, the plot is saved as html. If None, the plot is not saved.
display: bool or None, default=True
Whether to render the plot. If None, it returns the figure.
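The rule for filename (the extension selects the file type, and a name without one is saved as html) can be sketched with pathlib. This is a hypothetical helper written for illustration, not ATOM's code; the "auto" naming option is deliberately not modeled because its naming scheme is not described here.

```python
from __future__ import annotations

from pathlib import Path


def resolve_plot_path(filename: str | Path | None) -> Path | None:
    """Hypothetical helper: map a `filename` argument to the file to write.

    - None: nothing is saved.
    - No extension: ".html" is appended.
    - Otherwise the extension (.png, .pdf, ...) selects the file type.
    The "auto" option is intentionally not handled in this sketch.
    """
    if filename is None:
        return None
    path = Path(filename)
    if not path.suffix:
        path = path.with_suffix(".html")
    return path


print(resolve_plot_path("cv_splits"))    # cv_splits.html
print(resolve_plot_path("cv_splits.png"))  # cv_splits.png
```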
Returns
go.Figure or None
Plot object. Only returned if display=None.
See Also
Visualize the data splits.
Plot the trend, seasonality and residuals of a time series.
Plot pairwise relationships in a dataset.
Example
>>> from atom import ATOMClassifier, ATOMForecaster
>>> from random import choices
>>> from sklearn.datasets import load_breast_cancer
>>> from sktime.datasets import load_airline
>>> X, y = load_breast_cancer(return_X_y=True, as_frame=True)
>>> # Without groups
>>> atom = ATOMClassifier(X, y, shuffle=False, n_rows=0.2, random_state=1)
>>> atom.run("LR", metric=["f1", "auc"])
>>> atom.lr.cross_validate(cv=4)
      train_f1   test_f1  train_auc  test_auc      time
0     0.983051  0.962963   1.000000  1.000000  0.021019
1     0.991150  0.937500   1.000000  0.994872  0.021019
2     0.980000  0.976744   1.000000  1.000000  0.021019
3     0.990099  0.976744   0.999429  0.993197  0.022020
mean  0.986075  0.963488   0.999857  0.997017  0.021270
>>> atom.plot_cv_splits()
>>> # With groups
>>> groups = choices(["A", "B", "C", "D"], k=X.shape[0])
>>> atom = ATOMClassifier(X, y, metadata={"groups": groups}, n_rows=0.2, random_state=1)
>>> atom.run("LR", metric=["f1", "auc"])
>>> atom.lr.cross_validate(cv=4)
      train_f1   test_f1  train_auc  test_auc      time
0     1.000000  0.947368   1.000000  0.981481  0.022021
1     0.990476  0.972973   1.000000  1.000000  0.021019
2     0.981481  0.941176   1.000000  1.000000  0.021019
3     0.982143  0.937500   0.999449  0.993056  0.021019
mean  0.988525  0.949754   0.999862  0.993634  0.021270
>>> atom.plot_cv_splits()
>>> # For forecast models
>>> y = load_airline()
>>> atom = ATOMForecaster(y, random_state=1)
>>> atom.run("Croston", metric=["mape", "mse", "mae"])
>>> atom.croston.cross_validate(cv=4)
      test_mape      test_mse   test_mae      time
0      0.263562   4248.686875  57.152766  0.001885
1      0.170446   3539.682857  47.622303  0.001879
2      0.218499  10212.231474  85.077206  0.001953
3      0.144202   8875.839953  69.594088  0.001915
mean   0.199177   6719.110290  64.861591  0.001908
>>> atom.plot_cv_splits()
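In the grouped example, each sample carries a group label, and the plot shows the groups alongside the splits. A common group-aware strategy, documented for scikit-learn's GroupKFold, is to never let the same group appear in two different test folds. The stdlib sketch below illustrates that idea with a simple greedy rule (largest groups first, each assigned to the fold that currently holds the fewest samples). It is an assumption for illustration, not a statement about which splitter ATOM uses internally; `group_kfold` is a hypothetical helper.

```python
from collections import Counter


def group_kfold(groups: list[str], n_splits: int) -> list[list[int]]:
    """Hypothetical sketch: test indices per fold, never splitting a group.

    Groups are sorted by size (largest first, ties broken by name) and
    each one is assigned to the fold with the fewest samples so far.
    """
    sizes = Counter(groups)
    if len(sizes) < n_splits:
        raise ValueError("need at least as many distinct groups as folds")
    fold_of_group = {}
    fold_sizes = [0] * n_splits
    for group, size in sorted(sizes.items(), key=lambda kv: (-kv[1], kv[0])):
        lightest = fold_sizes.index(min(fold_sizes))
        fold_of_group[group] = lightest
        fold_sizes[lightest] += size
    # Collect the row indices of each fold's test set.
    folds = [[] for _ in range(n_splits)]
    for idx, group in enumerate(groups):
        folds[fold_of_group[group]].append(idx)
    return folds


groups = ["A", "B", "A", "C", "D", "B", "A", "C"]
for k, test in enumerate(group_kfold(groups, n_splits=4)):
    print(k, test, sorted({groups[i] for i in test}))
# 0 [0, 2, 6] ['A']
# 1 [1, 5] ['B']
# 2 [3, 7] ['C']
# 3 [4] ['D']
```

Because a whole group moves between train and test together, the test segments in a grouped plot are no longer contiguous blocks of rows, unlike the unshuffled case.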