plot_pipeline
method plot_pipeline(models=None, draw_hyperparameter_tuning=True, color_branches=None, title=None, legend=None, figsize=None, filename=None, display=True)
Plot a diagram of the pipeline.
Warning
This plot uses the schemdraw package, which is incompatible with plotly. The returned plot is therefore a matplotlib figure.
Parameters
models: int, str, Model, segment, sequence or None, default=None
Models for which to draw the pipeline. If None, all
pipelines are plotted.
draw_hyperparameter_tuning: bool, default=True
Whether to show in the diagram if the models used hyperparameter tuning.
color_branches: bool or None, default=None
Whether to draw every branch in a different color. If None,
branches are colored when there is more than one.
title: str, dict or None, default=None
Title for the plot.
legend: str, dict or None, default=None
Does nothing. Implemented for API consistency.
figsize: tuple or None, default=None
Figure size in pixels, formatted as (x, y). If None, the
size adapts to the pipeline drawn.
filename: str, Path or None, default=None
Save the plot using this name. Use "auto" for automatic
naming. The file type depends on the provided name
(.html, .png, .pdf, etc.). If filename has no file type,
the plot is saved as png. If None, the plot is not saved.
display: bool or None, default=True
Whether to render the plot. If None, it returns the figure.
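The defaults for color_branches and filename follow simple rules. The sketch below restates them as two hypothetical helpers; resolve_color_branches and resolve_filename are illustrative names, not part of ATOM's API, and this is not ATOM's implementation. The "auto" naming option is deliberately not modeled.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_color_branches(color_branches: Optional[bool], n_branches: int) -> bool:
    """Hypothetical helper restating the documented color_branches default.

    An explicit True/False is used as-is; None means "color the branches
    only when there is more than one".
    """
    if color_branches is None:
        return n_branches > 1
    return color_branches


def resolve_filename(filename: Optional[Union[str, Path]]) -> Optional[Path]:
    """Hypothetical helper restating the documented filename rules.

    None means the plot is not saved; a name without a file type is saved
    as png; any other extension is kept and decides the file type.
    """
    if filename is None:
        return None  # nothing is saved
    if str(filename) == "auto":
        # ATOM's automatic naming scheme is not reproduced in this sketch.
        raise NotImplementedError("automatic naming is not modeled here")
    path = Path(filename)
    if not path.suffix:
        path = path.with_suffix(".png")  # no file type given: default to png
    return path


print(resolve_color_branches(None, 3))   # more than one branch: colored
print(resolve_filename("pipeline"))      # no file type: saved as png
print(resolve_filename("pipeline.pdf"))  # extension decides the file type
```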
Returns
plt.Figure or None
Plot object. Only returned if display=None.
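Because this diagram is a matplotlib figure rather than a plotly one, a figure obtained with display=None is handled through matplotlib's own API, for example Figure.savefig. The following is a minimal sketch under stated assumptions: finish_plot is a hypothetical helper that mimics the documented display semantics (it is not ATOM code), and a placeholder matplotlib figure stands in for the schemdraw diagram.

```python
import os
import tempfile
from typing import Optional

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def finish_plot(fig: Figure, display: Optional[bool]) -> Optional[Figure]:
    """Hypothetical helper mimicking the documented display semantics."""
    if display is None:
        return fig  # caller receives the matplotlib Figure
    if display:
        plt.show()  # render the plot
    plt.close(fig)  # rendered or not, nothing is returned
    return None


# Placeholder figure standing in for the pipeline diagram.
fig, ax = plt.subplots(figsize=(4, 2))
ax.text(0.5, 0.5, "stand-in for the pipeline diagram", ha="center")
ax.set_axis_off()

returned = finish_plot(fig, display=None)
out = os.path.join(tempfile.mkdtemp(), "pipeline.pdf")
returned.savefig(out)  # matplotlib API, not plotly's write_image
```

With ATOM itself, the analogous pattern would be to call atom.plot_pipeline(display=None) and save the returned figure with savefig.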
Example
>>> from atom import ATOMClassifier
>>> from sklearn.datasets import load_breast_cancer
>>> X, y = load_breast_cancer(return_X_y=True, as_frame=True)
>>> # Example 1: a voting ensemble built on the two best models
>>> atom = ATOMClassifier(X, y, random_state=1)
>>> atom.run(["GNB", "RNN", "SGD", "MLP"])
>>> atom.voting(models=atom.winners[:2])
>>> atom.plot_pipeline()
>>> # Example 2: three branches with different data transformations
>>> atom = ATOMClassifier(X, y, random_state=1)
>>> atom.scale()
>>> atom.prune()
>>> atom.run("RF", n_trials=30)  # hyperparameter tuning with 30 trials
>>> # New branch that undersamples the data with NearMiss
>>> atom.branch = "undersample"
>>> atom.balance("nearmiss")
>>> atom.run("RF_undersample")
>>> # New branch split off from main that oversamples with SMOTE
>>> atom.branch = "oversample_from_main"
>>> atom.balance("smote")
>>> atom.run("RF_oversample")
>>> atom.plot_pipeline()