Plots
ATOM provides many plotting methods to analyze the data or compare model performances. Descriptions and examples can be found in the API section. ATOM mainly uses the plotly library for plotting. Plotly makes interactive, publication-quality graphs that are rendered as HTML. Some plots require other libraries, like matplotlib, shap, wordcloud and schemdraw.
Plots that compare model performances (methods with the `models` parameter) can be called directly from atom, e.g. `atom.plot_roc()`, or from one of the models, e.g. `atom.adab.plot_roc()`. If called from atom, use the `models` parameter to specify which models to plot. If called from a specific model, the plot is made only for that model and the `models` parameter becomes unavailable.
Plots that analyze the data (methods without the `models` parameter) can only be called from atom, not from the models.
Parameters
Apart from the plot-specific parameters, all plots have five parameters in common:
- The `title` parameter adds a title to the plot. By default, no title is shown. Provide a configuration (as dictionary) to customize its appearance, e.g. `title=dict(text="Awesome plot", color="red")`. Read more in plotly's documentation.
- The `legend` parameter is used to show/hide, position or customize the plot's legend. Provide a configuration (as dictionary) to customize its appearance, e.g. `legend=dict(title="Title for legend", title_font_color="red")`, or choose one of the following locations:
    - upper left
    - upper right
    - lower left
    - lower right
    - upper center
    - lower center
    - center left
    - center right
    - center
    - out: Position the legend outside the axes, on the right-hand side. This is plotly's default position. Note that this shrinks the size of the axes to fit both legend and axes in the specified `figsize`.
- The `figsize` parameter adjusts the plot's size.
- The `filename` parameter is used to save the plot.
- The `display` parameter determines whether to show or return the plot.
Aesthetics
The plot's aesthetics can be customized using the plot attributes, e.g. `atom.title_fontsize = 30`. The default values are:
- palette: ["rgb(0, 98, 98)", "rgb(56, 166, 165)", "rgb(115, 175, 72)", "rgb(237, 173, 8)", "rgb(225, 124, 5)", "rgb(204, 80, 62)", "rgb(148, 52, 110)", "rgb(111, 64, 112)", "rgb(102, 102, 102)"]
- title_fontsize: 24
- label_fontsize: 16
- tick_fontsize: 12
Use atom's `update_layout` method to further customize the plot's aesthetics using any of plotly's layout properties, e.g. `atom.update_layout(template="plotly_dark")`. Use the `reset_aesthetics` method to reset the aesthetics to their default values. See advanced plotting for various examples.
Canvas
Use the `canvas` method to draw multiple plots side by side, for example to make it easier to compare similar results. The canvas method is a `@contextmanager`, i.e. it's used through Python's `with` statement. Plots in a canvas ignore the legend, figsize, filename and display parameters. Instead, specify these parameters in the canvas. If a variable is assigned to the canvas (e.g. `with atom.canvas() as fig`), it yields the resulting figure.
For example, we can use a canvas to compare the results of an XGBoost and a LightGBM model on the train and test set. We could also draw the lines for both models in the same axes, but that would clutter the plot too much. Click here for more examples.
>>> from atom import ATOMClassifier
>>> import pandas as pd
>>> X = pd.read_csv("./examples/datasets/weatherAUS.csv")
>>> atom = ATOMClassifier(X, y="RainTomorrow")
>>> atom.impute()
>>> atom.encode()
>>> atom.run(["xgb", "lgb"])
>>> with atom.canvas(2, 2, title="XGBoost vs LightGBM"):
... atom.xgb.plot_roc(dataset="both", title="ROC - XGBoost")
... atom.lgb.plot_roc(dataset="both", title="ROC - LightGBM")
... atom.xgb.plot_prc(dataset="both", title="PRC - XGBoost")
... atom.lgb.plot_prc(dataset="both", title="PRC - LightGBM")
SHAP
The SHAP (SHapley Additive exPlanations) Python package uses a game theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions. ATOM implements methods to call 7 of SHAP's plotting functions directly from its API. A list of available shap plots can be found here.
Calculating the Shapley values is computationally expensive, especially for model-agnostic explainers like Permutation. To avoid having to recalculate the values for every plot, ATOM stores the Shapley values internally after the first calculation, and accesses them when needed again.
Note
Since the plot figures are not made by ATOM, note the following:
- It's not possible to draw multiple models in the same figure. Selecting more than one model will raise an exception. To avoid this, call the plot directly from a model, e.g. `atom.lr.plot_shap_force()`.
- The returned plot is a matplotlib figure, not a plotly one.
Available plots
A list of available plots can be found hereunder. Note that not all plots can be called from every class and that their availability can depend on the task at hand.
Feature selection plots
| Method | Description |
| --- | --- |
| plot_components | Plot the explained variance ratio per component. |
| plot_pca | Plot the explained variance ratio vs number of components. |
| plot_rfecv | Plot the rfecv results. |
Data plots
| Method | Description |
| --- | --- |
| plot_correlation | Plot a correlation matrix. |
| plot_distribution | Plot column distributions. |
| plot_ngrams | Plot n-gram frequencies. |
| plot_qq | Plot a quantile-quantile plot. |
| plot_relationships | Plot pairwise relationships in a dataset. |
| plot_wordcloud | Plot a wordcloud from the corpus. |
Hyperparameter tuning plots
| Method | Description |
| --- | --- |
| plot_edf | Plot the Empirical Distribution Function of a study. |
| plot_hyperparameter_importance | Plot a model's hyperparameter importance. |
| plot_hyperparameters | Plot hyperparameter relationships in a study. |
| plot_parallel_coordinate | Plot high-dimensional parameter relationships in a study. |
| plot_pareto_front | Plot the Pareto front of a study. |
| plot_slice | Plot the parameter relationship in a study. |
| plot_trials | Plot the hyperparameter tuning trials. |
Prediction plots
| Method | Description |
| --- | --- |
| plot_calibration | Plot the calibration curve for a binary classifier. |
| plot_confusion_matrix | Plot a model's confusion matrix. |
| plot_det | Plot the Detection Error Tradeoff curve. |
| plot_errors | Plot a model's prediction errors. |
| plot_evals | Plot evaluation curves. |
| plot_feature_importance | Plot a model's feature importance. |
| plot_gains | Plot the cumulative gains curve. |
| plot_learning_curve | Plot the learning curve: score vs number of training samples. |
| plot_lift | Plot the lift curve. |
| plot_parshap | Plot the partial correlation of shap values. |
| plot_partial_dependence | Plot the partial dependence of features. |
| plot_permutation_importance | Plot the feature permutation importance of models. |
| plot_pipeline | Plot a diagram of the pipeline. |
| plot_prc | Plot the precision-recall curve. |
| plot_probabilities | Plot the probability distribution of the target classes. |
| plot_residuals | Plot a model's residuals. |
| plot_results | Plot the model results. |
| plot_roc | Plot the Receiver Operating Characteristics curve. |
| plot_successive_halving | Plot scores per iteration of the successive halving. |
| plot_threshold | Plot metric performances against threshold values. |
SHAP plots
| Method | Description |
| --- | --- |
| plot_shap_bar | Plot SHAP's bar plot. |
| plot_shap_beeswarm | Plot SHAP's beeswarm plot. |
| plot_shap_decision | Plot SHAP's decision plot. |
| plot_shap_force | Plot SHAP's force plot. |
| plot_shap_heatmap | Plot SHAP's heatmap plot. |
| plot_shap_scatter | Plot SHAP's scatter plot. |
| plot_shap_waterfall | Plot SHAP's waterfall plot. |