plot_scatter_matrix


method plot_scatter_matrix(columns=None, title=None, figsize=(10, 10), filename=None, display=True, **kwargs) [source]

Plot a matrix of scatter plots. To avoid cluttering the plot, a random subset of at most 250 samples is selected from every column.
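The selection step described above can be sketched with plain pandas. This is only an illustration, not the library's implementation: the helper name `prepare_scatter_data`, its `max_samples` and `random_state` arguments, and the choice to sample the same rows for every column are all assumptions made for the example.

```python
import pandas as pd


def prepare_scatter_data(df, columns=None, max_samples=250, random_state=1):
    """Hypothetical helper: select columns, drop non-numeric ones, subsample rows."""
    if columns is None:
        subset = df  # plot all columns
    elif isinstance(columns, slice):
        subset = df.iloc[:, columns]  # positional slice of columns
    else:
        # Sequence of names and/or integer positions: map positions to names
        names = [df.columns[c] if isinstance(c, int) else c for c in columns]
        subset = df[names]

    # Categorical (non-numeric) columns are ignored
    numeric = subset.select_dtypes(include="number")

    # Keep at most max_samples rows, drawn at random (assumed: same rows for all columns)
    n = min(len(numeric), max_samples)
    return numeric.sample(n=n, random_state=random_state)


# Toy data: two numeric columns and one categorical column
df = pd.DataFrame({
    "a": range(1000),
    "b": [x * 0.5 for x in range(1000)],
    "c": ["x", "y"] * 500,
})
sample = prepare_scatter_data(df)
```

The resulting frame holds only the numeric columns and no more than 250 rows, which is the kind of input one would then hand to seaborn's pairplot.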

Parameters:

columns: slice, sequence or None, optional (default=None)
Slice, names or indices of the columns to plot. If None, plot all columns in the dataset. Selected categorical columns are ignored.

title: str or None, optional (default=None)
Plot's title. If None, the title is left empty.

figsize: tuple, optional (default=(10, 10))
Figure's size, formatted as (x, y).

filename: str or None, optional (default=None)
Name of the file. Use "auto" for automatic naming. If None, the figure is not saved.

display: bool or None, optional (default=True)
Whether to render the plot. If None, the plot is not rendered and the matplotlib figure is returned instead.

**kwargs
Additional keyword arguments for seaborn's pairplot.

Returns: fig: matplotlib.figure.Figure
Plot object. Only returned if display=None.
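As a rough illustration of how the title, filename and display parameters relate to the return value, the sketch below mimics the documented behaviour with plain matplotlib. The helper `finish_plot` is hypothetical, and the file name used for "auto" is an assumption; the library's actual naming scheme is not stated here.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt


def finish_plot(fig, title=None, filename=None, display=True):
    """Hypothetical helper mirroring the documented title/filename/display rules."""
    if title is not None:
        fig.suptitle(title)  # if None, the title is left empty

    if filename is not None:
        # Assumption: "auto" falls back to the name of the plotting method
        name = "plot_scatter_matrix" if filename == "auto" else filename
        fig.savefig(name)

    if display is None:
        return fig  # figure is handed back instead of being rendered

    if display:
        plt.show()
    plt.close(fig)
    return None


fig = plt.figure(figsize=(10, 10))
returned = finish_plot(fig, title="Scatter matrix", display=None)
```

With display=None the caller keeps the figure and can adjust it further; with display=True or display=False nothing is returned.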


Example

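from sklearn.datasets import load_breast_cancer
from atom import ATOMClassifier

# Example data only; any feature set X and target y can be used
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

atom = ATOMClassifier(X, y)
atom.plot_scatter_matrix(columns=slice(0, 5))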