plot_data_splits
method plot_data_splits(title=None, legend="upper right", figsize=(900, 600), filename=None, display=True)
Visualize the data splits.
Plots the train, test and holdout splits. The x-axis shows the row index, so every point corresponds to the n-th sample. Class labels and groups are also plotted when relevant.
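To make the x-axis concrete, here is a minimal sketch in plain Python (no ATOM or plotting library involved). It labels each row index with its split, then collapses consecutive rows that share a label into (label, start, end) spans, roughly the extent each split would cover on the x-axis. The helpers `assign_splits` and `split_segments` are hypothetical. The train, then test, then holdout ordering and the `int()` rounding are illustrative assumptions, and ATOM's own ordering and sizing may differ.

```python
from itertools import groupby


def assign_splits(n_rows: int, test_size: float, holdout_size: float) -> list[str]:
    """Label every row index with its split (hypothetical helper).

    Assumes rows are ordered train -> test -> holdout; this ordering and the
    int() rounding are illustrative, not ATOM's documented behavior.
    """
    n_holdout = int(n_rows * holdout_size)
    n_test = int(n_rows * test_size)
    n_train = n_rows - n_test - n_holdout
    return ["train"] * n_train + ["test"] * n_test + ["holdout"] * n_holdout


def split_segments(labels: list[str]) -> list[tuple[str, int, int]]:
    """Collapse consecutive equal labels into (label, start, end) spans.

    `end` is exclusive, so each span covers row indices start..end-1.
    """
    segments = []
    start = 0
    for label, run in groupby(labels):
        length = sum(1 for _ in run)  # size of this run of identical labels
        segments.append((label, start, start + length))
        start += length
    return segments


labels = assign_splits(n_rows=10, test_size=0.2, holdout_size=0.1)
print(split_segments(labels))
# [('train', 0, 7), ('test', 7, 9), ('holdout', 9, 10)]
```

`split_segments` only compares neighbouring labels. It could therefore be applied just as well to a per-row sequence of class labels or to the `groups` list built in the example.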
Parameters
title: str, dict or None, default=None
Title for the plot.
legend: str, dict or None, default="upper right"
Legend for the plot. See the user guide for
an extended description of the choices.
figsize: tuple, default=(900, 600)
Figure size in pixels, given as (width, height).
filename: str, Path or None, default=None
Save the plot using this name. Use "auto" for automatic
naming. The type of the file depends on the provided name
(.html, .png, .pdf, etc.). If filename has no file type,
the plot is saved as html. If None, the plot is not saved.
display: bool or None, default=True
Whether to render the plot. If None, the figure is returned instead.
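The `filename` rules above can be summarized as a small path-resolution function. This is a hypothetical sketch, not ATOM's implementation. `resolve_filename` is an invented name, and using the plot's name for "auto" is an assumption, since the docs only say that "auto" triggers automatic naming.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_filename(
    filename: Union[str, Path, None], plot_name: str = "plot_data_splits"
) -> Optional[Path]:
    """Return the output path for a plot, or None if nothing should be saved.

    Hypothetical helper mirroring the documented rules for `filename`:
      - None            -> the plot is not saved
      - "auto"          -> automatic naming (assumed here: the plot's name)
      - no file suffix  -> saved as .html
      - any suffix      -> that file type is kept (.png, .pdf, ...)
    """
    if filename is None:
        return None
    path = Path(plot_name) if str(filename) == "auto" else Path(filename)
    if not path.suffix:
        path = path.with_suffix(".html")
    return path


for name in (None, "auto", "splits", "splits.png"):
    print(repr(name), "->", resolve_filename(name))
```

Checking `path.suffix` means a name such as `splits.v2` would count as having a file type. A real implementation might validate the suffix against the formats it supports.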
Returns
go.Figure or None
Plot object. Only returned if display=None.
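The interplay between `display` and the return value can be sketched as follows. `finish_plot` and its `render` callback are hypothetical. The True and None branches follow the documented behavior. Returning None for `display=False` is an assumption, because that case is not spelled out here.

```python
from typing import Callable, Optional, TypeVar

Fig = TypeVar("Fig")


def finish_plot(
    fig: Fig, display: Optional[bool], render: Callable[[Fig], None]
) -> Optional[Fig]:
    """Hypothetical sketch of the documented `display` contract.

    True  -> render the figure and return None.
    None  -> skip rendering and hand the figure back to the caller.
    False -> (assumed) skip rendering and return None.
    """
    if display is None:
        return fig
    if display:
        render(fig)
    return None


rendered = []
print(finish_plot("figure", None, rendered.append))  # figure
print(rendered)  # [] (nothing was rendered)
```

This is what makes `fig = atom.plot_data_splits(display=None)` useful: you get the `go.Figure` back and can adjust or save it yourself.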
See Also
Visualize the cross-validation splits.
Plot the trend, seasonality and residuals of a time series.
Plot pairwise relationships in a dataset.
Example
>>> from atom import ATOMClassifier
>>> from random import choices
>>> from sklearn.datasets import load_breast_cancer
>>> X, y = load_breast_cancer(return_X_y=True, as_frame=True)
>>> groups = choices(["A", "B", "C", "D"], k=X.shape[0])
>>> atom = ATOMClassifier(
...     X,
...     y=y,
...     metadata={"groups": groups},
...     n_rows=0.2,
...     holdout_size=0.1,
...     random_state=1,
... )
>>> atom.run("LR")
>>> atom.plot_data_splits()