Plots
ATOM provides many plotting methods to analyze the data or compare model performance. Descriptions and examples can be found in the API section. ATOM mainly uses the plotly library for plotting. Plotly makes interactive, publication-quality graphs that are rendered as HTML. Some plots require other libraries, such as matplotlib, shap, wordcloud and schemdraw.
Plots that compare model performances (methods with the models parameter) can be called directly from atom, e.g., atom.plot_roc(), or from one of the models, e.g., atom.adab.plot_roc(). If called from atom, use the models parameter to specify which models to plot. If called from a specific model, it makes the plot only for that model and the models parameter becomes unavailable.
Plots that analyze the data (methods without the models parameter) can only be called from atom, and not from the models.
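As an illustration, the sketch below shows both call styles. It is a hypothetical setup that assumes the AdaBoost (AdaB) and Random Forest (RF) models have been trained; any dataset and any two models would do.
from atom import ATOMClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, flip_y=0.2, random_state=1)

atom = ATOMClassifier(X, y)
atom.run(["AdaB", "RF"])

# Called from atom: compare a selection of models via the models parameter
atom.plot_roc(models=["AdaB", "RF"])

# Called from a single model: only that model is plotted
atom.adab.plot_roc()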
Parameters
Apart from the plot-specific parameters, all plots have five parameters in common:
- The title parameter adds a title to the plot. The default value doesn't show any title. Provide a configuration (as dictionary) to customize its appearance, e.g., title=dict(text="Awesome plot", color="red"). Read more in plotly's documentation.
- The legend parameter is used to show/hide, position or customize the plot's legend. Provide a configuration (as dictionary) to customize its appearance (e.g., legend=dict(title="Title for legend", title_font_color="red")) or choose one of the following locations:
    - upper left
    - upper right
    - lower left
    - lower right
    - upper center
    - lower center
    - center left
    - center right
    - center
    - out: Position the legend outside the axis, on the right-hand side. This is plotly's default position. Note that this shrinks the size of the axis to fit both legend and axes in the specified figsize.
- The figsize parameter adjusts the plot's size.
- The filename parameter is used to save the plot.
- The display parameter determines whether to show or return the plot.
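To illustrate, a hypothetical call combining these shared parameters could look like the sketch below; the figsize value is only indicative and follows whatever size convention the plot method expects.
atom.plot_roc(
    title=dict(text="ROC curves", color="red"),
    legend="lower right",
    figsize=(900, 600),
    filename="roc_curves",  # save the plot to disk
    display=True,           # show the figure as well
)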
Info
In some plotting methods, it's possible to plot separate lines for different subsets of the rows, for example, to compare the results on the train and test set. For these cases, either provide a sequence to the rows parameter with one entry for every line you want to draw, e.g., atom.plot_roc(rows=("train", "test")), or provide a dictionary where the keys are the names of the sets (used in the legend) and the values are the corresponding selection of rows, selected using any of the aforementioned approaches, e.g., atom.plot_roc(rows={"0-99": range(100), "100-199": range(100, 200)}).
Note that for these methods, using atom.plot_roc(rows="train+test") only plots one line with the data from both sets. See the advanced plotting example.
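Side by side, the three forms look as follows; the row selections are only illustrative.
# One line per element in the sequence
atom.plot_roc(rows=("train", "test"))

# One line per key, with the keys shown in the legend
atom.plot_roc(rows={"0-99": range(100), "100-199": range(100, 200)})

# A single line drawn from the union of both sets
atom.plot_roc(rows="train+test")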
Aesthetics
The plot's aesthetics can be customized using the plot attributes prior to calling the plotting method, e.g., atom.title_fontsize = 30. The default values are:
- palette: ["rgb(0, 98, 98)", "rgb(56, 166, 165)", "rgb(115, 175, 72)", "rgb(237, 173, 8)", "rgb(225, 124, 5)", "rgb(204, 80, 62)", "rgb(148, 52, 110)", "rgb(111, 64, 112)", "rgb(102, 102, 102)"]
- title_fontsize: 24
- label_fontsize: 16
- tick_fontsize: 12
Use atom's update_layout method to further customize the plot's layout using any of plotly's layout properties, e.g., atom.update_layout(template="plotly_dark"). Similarly, use the update_traces method to customize the trace properties, e.g., atom.update_traces(mode="lines+markers").
The reset_aesthetics method allows you to reset all aesthetics to their default values. See advanced plotting for an example.
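As a sketch, an aesthetics workflow using the attributes and methods above could look like this; the chosen values are arbitrary.
# Set plot attributes before calling a plotting method
atom.title_fontsize = 30
atom.label_fontsize = 20

# Tweak the plotly layout and traces
atom.update_layout(template="plotly_dark")
atom.update_traces(mode="lines+markers")

atom.plot_results()

# Restore all aesthetics to their default values
atom.reset_aesthetics()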
Canvas
Use the canvas method to draw multiple plots side by side, for example to make it easier to compare similar results. The canvas method is a @contextmanager, i.e., it's used through Python's with statement.
Plots in a canvas ignore the legend, figsize, filename and display parameters. Instead, specify these parameters in the canvas. If a variable is assigned to the canvas (e.g., with atom.canvas() as fig), it yields the resulting figure.
For example, we can use a canvas to compare the results of an XGBoost and a LightGBM model on the train and test set. We could also draw the lines for both models in the same axes, but that would clutter the plot too much. Click here for more examples.
from atom import ATOMClassifier
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=1000, flip_y=0.2, random_state=1)
atom = ATOMClassifier(X, y)
atom.run(["XGB", "LGB"])
with atom.canvas(2, 2, title="XGBoost vs LightGBM"):
    atom.xgb.plot_roc(rows="train+test", title="ROC - XGBoost")
    atom.lgb.plot_roc(rows="train+test", title="ROC - LightGBM")
    atom.xgb.plot_prc(rows="train+test", title="PRC - XGBoost")
    atom.lgb.plot_prc(rows="train+test", title="PRC - LightGBM")
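Since the canvas yields the resulting figure, it can also be captured for further processing. A minimal sketch, assuming the yielded object is a regular plotly figure:
with atom.canvas(1, 2, title="ROC curves") as fig:
    atom.xgb.plot_roc(rows="train+test")
    atom.lgb.plot_roc(rows="train+test")

# fig is the combined figure and can be handled like any plotly figure
fig.write_html("roc_curves.html")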
SHAP
The SHAP (SHapley Additive exPlanations) Python package uses a game-theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions. ATOM implements methods to plot 7 of SHAP's plotting functions directly from its API. A list of available shap plots can be found here.
Calculating the Shapley values is computationally expensive, especially for model-agnostic explainers like Permutation. To avoid having to recalculate the values for every plot, ATOM stores the Shapley values internally after the first calculation and accesses them again when needed.
Warning
- It's not possible to draw multiple models in the same figure. Selecting more than one model will raise an exception. To avoid this, call the plot directly from a model, e.g., atom.lr.plot_shap_force().
- The returned plot is a matplotlib figure, not a plotly one.
- SHAP plots aren't available for forecast tasks.
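A minimal sketch of calling SHAP plots, assuming a logistic regression model (LR) has been trained on some binary classification data:
from atom import ATOMClassifier
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

atom = ATOMClassifier(X, y)
atom.run("LR")

# SHAP plots are called from a single model, not from atom with several models
atom.lr.plot_shap_beeswarm()
atom.lr.plot_shap_bar()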
Available plots
A list of available plots can be found below. Note that not all plots can be called from every class and that their availability can depend on the task at hand.
Data plots
| Plot | Description |
| --- | --- |
| plot_acf | Plot the autocorrelation function. |
| plot_ccf | Plot the cross-correlation between two time series. |
| plot_components | Plot the explained variance ratio per component. |
| plot_correlation | Plot a correlation matrix. |
| plot_data_splits | Visualize the data splits. |
| plot_decomposition | Plot the trend, seasonality and residuals of a time series. |
| plot_distribution | Plot column distributions. |
| plot_fft | Plot the Fourier transformation of a time series. |
| plot_ngrams | Plot n-gram frequencies. |
| plot_pacf | Plot the partial autocorrelation function. |
| plot_pca | Plot the explained variance ratio vs number of components. |
| plot_periodogram | Plot the spectral density of a time series. |
| plot_qq | Plot a quantile-quantile plot. |
| plot_relationships | Plot pairwise relationships in a dataset. |
| plot_rfecv | Plot the rfecv results. |
| plot_series | Plot a data series. |
| plot_wordcloud | Plot a wordcloud from the corpus. |
Hyperparameter tuning plots
| Plot | Description |
| --- | --- |
| plot_edf | Plot the Empirical Distribution Function of a study. |
| plot_hyperparameter_importance | Plot a model's hyperparameter importance. |
| plot_hyperparameters | Plot hyperparameter relationships in a study. |
| plot_parallel_coordinate | Plot high-dimensional parameter relationships in a study. |
| plot_pareto_front | Plot the Pareto front of a study. |
| plot_slice | Plot the parameter relationship in a study. |
| plot_terminator_improvement | Plot the potentials for future objective improvement. |
| plot_timeline | Plot the timeline of a study. |
| plot_trials | Plot the hyperparameter tuning trials. |
Prediction plots
| Plot | Description |
| --- | --- |
| plot_bootstrap | Plot the bootstrapping scores. |
| plot_calibration | Plot the calibration curve for a binary classifier. |
| plot_confusion_matrix | Plot a model's confusion matrix. |
| plot_cv_splits | Visualize the cross-validation splits. |
| plot_det | Plot the Detection Error Tradeoff curve. |
| plot_errors | Plot a model's prediction errors. |
| plot_evals | Plot evaluation curves. |
| plot_feature_importance | Plot a model's feature importance. |
| plot_forecast | Plot model forecasts for the target time series. |
| plot_gains | Plot the cumulative gains curve. |
| plot_learning_curve | Plot the learning curve: score vs number of training samples. |
| plot_lift | Plot the lift curve. |
| plot_parshap | Plot the partial correlation of shap values. |
| plot_partial_dependence | Plot the partial dependence of features. |
| plot_permutation_importance | Plot the feature permutation importance of models. |
| plot_pipeline | Plot a diagram of the pipeline. |
| plot_prc | Plot the precision-recall curve. |
| plot_probabilities | Plot the probability distribution of the target classes. |
| plot_residuals | Plot a model's residuals. |
| plot_results | Compare metric results of the models. |
| plot_roc | Plot the Receiver Operating Characteristics curve. |
| plot_successive_halving | Plot scores per iteration of the successive halving. |
| plot_threshold | Plot metric performances against threshold values. |
Shap plots
| Plot | Description |
| --- | --- |
| plot_shap_bar | Plot SHAP's bar plot. |
| plot_shap_beeswarm | Plot SHAP's beeswarm plot. |
| plot_shap_decision | Plot SHAP's decision plot. |
| plot_shap_force | Plot SHAP's force plot. |
| plot_shap_heatmap | Plot SHAP's heatmap plot. |
| plot_shap_scatter | Plot SHAP's scatter plot. |
| plot_shap_waterfall | Plot SHAP's waterfall plot. |