
plot_relationships


method plot_relationships(columns=(0, 1, 2), title=None, legend=None, figsize=(900, 900), filename=None, display=True)
Plot pairwise relationships in a dataset.

Creates a grid of axes in which each selected numerical column appears once on the x-axes and once on the y-axes. The lower triangle contains scatter plots (of at most 250 randomly sampled rows), the diagonal contains the distribution of each column, and the upper triangle contains contour histograms computed from all samples in the columns.
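The layout described above can be sketched as a rule that maps each grid cell to a panel type, plus a helper that draws the row sample used by the scatter panels. This is a minimal illustration, not ATOM's internal implementation: `panel_kind` and `sample_rows` are hypothetical names, and the fixed random seed is an assumption made so the sample is reproducible.

```python
import random


# Hypothetical helper (not part of ATOM): which kind of panel occupies
# cell (row, col) of an n x n pairwise grid, with row 0 at the top.
def panel_kind(row: int, col: int) -> str:
    if row == col:
        return "distribution"  # diagonal: distribution of a single column
    if row > col:
        return "scatter"       # lower triangle: scatter of sampled rows
    return "contour"           # upper triangle: contour histogram of all rows


# Hypothetical helper: pick at most `max_samples` distinct row indices
# for the scatter panels; the seed is an assumption for reproducibility.
def sample_rows(n_rows: int, max_samples: int = 250, seed: int = 0) -> list[int]:
    if n_rows <= max_samples:
        return list(range(n_rows))
    rng = random.Random(seed)
    return sorted(rng.sample(range(n_rows), max_samples))


if __name__ == "__main__":
    n = 3  # e.g. columns=(0, 1, 2)
    for r in range(n):
        print([panel_kind(r, c) for c in range(n)])
```

Sampling only the scatter panels keeps the lower triangle readable on large datasets, while the contour histograms in the upper triangle can still summarize every row.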

Parameters

columns: slice or sequence, default=(0, 1, 2)
Columns to plot. Selected categorical columns are ignored.

title: str, dict or None, default=None
Title for the plot.

legend: str, dict or None, default=None
Does nothing. Included for consistency with the rest of the API.

figsize: tuple, default=(900, 900)
Figure size in pixels, in the format (x, y).

filename: str or None, default=None
Save the plot using this name. Use "auto" for automatic naming. The file type is determined by the extension of the provided name (.html, .png, .pdf, etc.). If the name has no extension, the plot is saved as HTML. If None, the plot is not saved.

display: bool or None, default=True
Whether to render the plot. If None, it returns the figure.
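Two of the parameters above follow small resolution rules: `columns` selects columns and then drops categorical ones, and `filename` falls back to HTML when no extension is given. The sketch below restates those rules in plain Python under stated assumptions. It is not ATOM's code: `select_numeric_columns` and `resolve_filename` are hypothetical, the table is modeled as a `{name: values}` dict instead of a DataFrame, accepting column names alongside positions is an assumption, and the `auto_name` placeholder stands in for ATOM's automatic naming, which is not described here.

```python
from pathlib import Path


# Hypothetical helper: resolve `columns` (a slice, or a sequence of positions
# or names) against a {name: values} table, then drop non-numeric columns.
def select_numeric_columns(table: dict[str, list], columns) -> list[str]:
    names = list(table)
    if isinstance(columns, slice):
        selected = names[columns]
    else:
        selected = [names[c] if isinstance(c, int) else c for c in columns]

    def is_numeric(values: list) -> bool:
        # bool is a subclass of int, so exclude it explicitly.
        return all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values)

    return [name for name in selected if is_numeric(table[name])]


# Hypothetical helper: return the output path, or None if nothing is saved.
def resolve_filename(filename, auto_name: str = "plot_relationships"):
    if filename is None:
        return None
    if filename == "auto":
        filename = auto_name  # placeholder for ATOM's automatic naming
    path = Path(filename)
    if not path.suffix:
        path = path.with_suffix(".html")  # no file type given: save as HTML
    return path
```

For example, with a table holding a string-valued `color` column between two numeric columns, `select_numeric_columns(table, (0, 1, 2))` keeps only the two numeric names, and `resolve_filename("pairs")` yields `pairs.html`.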

Returns

go.Figure or None
Plot object. Only returned if display=None.
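The three-valued `display` argument can be read as the control flow below. This is a sketch, assuming that `display=False` neither renders nor returns the figure (saving through `filename` would happen independently of this choice). `FakeFigure` and `finish_plot` are hypothetical stand-ins so the example runs without plotly; they are not ATOM or plotly names.

```python
from typing import Optional


class FakeFigure:
    """Hypothetical stand-in for go.Figure; records whether show() was called."""

    def __init__(self) -> None:
        self.shown = False

    def show(self) -> None:
        self.shown = True


# Hypothetical sketch of the documented behavior of `display`.
def finish_plot(fig: FakeFigure, display: Optional[bool]) -> Optional[FakeFigure]:
    if display is None:
        return fig      # display=None: hand the figure back to the caller
    if display:
        fig.show()      # display=True: render the plot
    return None         # display=True or False: nothing is returned
```

Using `is None` rather than a truthiness check matters here, because `None` and `False` are both falsy but select different behaviors.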


See Also

plot_correlation

Plot a correlation matrix.

plot_distribution

Plot column distributions.

plot_qq

Plot a quantile-quantile plot.


Example

>>> from atom import ATOMClassifier
>>> from sklearn.datasets import load_breast_cancer

>>> X, y = load_breast_cancer(return_X_y=True, as_frame=True)

>>> atom = ATOMClassifier(X, y)
>>> atom.plot_relationships(columns=[0, 4, 5])