Predicting
After training a model, you probably want to make predictions on new,
unseen data. Just like with a sklearn estimator, you can call the prediction
methods from the model, e.g., `atom.tree.predict(X)`.
All prediction methods transform the provided data through the pipeline in the model's branch before making the predictions. Transformers that should only be applied to the training set (e.g., outlier pruning or class balancing) are excluded from this step.
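The idea of skipping train-only transformers can be illustrated with a small, self-contained sketch. It is not ATOM's implementation: the `Step` class, its `train_only` flag and the `transform_for_prediction` helper are hypothetical names chosen for this example, and integer lists stand in for real feature data.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Step:
    """Hypothetical pipeline step: a name, a transform function and a flag."""
    name: str
    transform: Callable[[List[int]], List[int]]
    train_only: bool = False  # True for steps like outlier pruning or balancing


def transform_for_prediction(steps: List[Step], X: List[int]) -> List[int]:
    """Run the data through every step except the train-only ones."""
    for step in steps:
        if step.train_only:
            continue  # skipped when making predictions
        X = step.transform(X)
    return X


steps = [
    Step("scale", lambda xs: [x * 2 for x in xs]),
    Step("prune_outliers", lambda xs: [x for x in xs if x < 10], train_only=True),
    Step("shift", lambda xs: [x + 1 for x in xs]),
]

print(transform_for_prediction(steps, [1, 2, 30]))  # [3, 5, 61]
```

Because the pruning step is skipped, the large value survives and still receives a prediction, which is the desired behavior on new data: every row passed in should get an output.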
The available prediction methods are the standard methods for predictors in the sklearn and sktime APIs.
For classification and regression tasks:
| Method | Description |
|--------|-------------|
| decision_function | Get confidence scores on new data or existing rows. |
| predict | Get predictions on new data or existing rows. |
| predict_log_proba | Get class log-probabilities on new data or existing rows. |
| predict_proba | Get class probabilities on new data or existing rows. |
| score | Get a metric score on new data. |
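For many sklearn classifiers, the outputs of these methods are related: `predict_log_proba` is the natural logarithm of `predict_proba`, and `predict` returns the class with the highest probability (some estimators, such as SVC with Platt scaling, can deviate from the latter). The sketch below shows that relationship with hard-coded, made-up probabilities instead of a real model.

```python
import math

classes = ["cat", "dog", "fish"]

# Made-up output of predict_proba for two rows (each row sums to 1).
proba = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
]

# predict_log_proba: element-wise natural log of the probabilities.
log_proba = [[math.log(p) for p in row] for row in proba]

# predict: the class with the highest probability in each row.
predictions = [classes[row.index(max(row))] for row in proba]

print(predictions)  # ['cat', 'fish']
```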
For forecast tasks:
| Method | Description |
|--------|-------------|
| predict | Get predictions on new data or existing rows. |
| predict_interval | Get prediction intervals on new data or existing rows. |
| predict_proba | Get probabilistic forecasts on new data or existing rows. |
| predict_quantiles | Get quantile forecasts on new data or existing rows. |
| predict_var | Get variance forecasts on new data or existing rows. |
| score | Get a metric score on new data. |
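The probabilistic forecast methods are closely linked. In sktime's convention, as we understand it, a prediction interval with coverage `c` is bounded by the `(1 - c) / 2` and `(1 + c) / 2` quantiles. The sketch below derives such an interval from a point forecast and its variance, under the explicit assumption that forecast errors are normally distributed; `interval_from_variance` is a hypothetical helper, not part of ATOM or sktime.

```python
from statistics import NormalDist


def interval_from_variance(mean: float, var: float, coverage: float = 0.9):
    """Symmetric prediction interval, assuming normally distributed errors."""
    dist = NormalDist(mu=mean, sigma=var ** 0.5)
    alpha = (1 - coverage) / 2
    # The bounds are the alpha and (1 - alpha) quantiles of the distribution.
    return dist.inv_cdf(alpha), dist.inv_cdf(1 - alpha)


# A point forecast of 100 with a variance of 25 (standard deviation 5).
lower, upper = interval_from_variance(mean=100.0, var=25.0, coverage=0.9)
print(round(lower, 2), round(upper, 2))
```

Real forecasters may produce asymmetric or non-normal distributions, so `predict_interval` and `predict_quantiles` can differ from this normal approximation.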
Warning
The `score` method returns atom's metric score, not the metric returned
by sklearn/sktime's `score` method for predictors. Use the method's
`metric` parameter to calculate a different metric.
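The behavior described in the warning (score with the model's own metric by default, or with another metric when one is passed) can be sketched with a toy class. `ToyModel`, `METRICS` and `error_rate` are hypothetical names for illustration only; this is not ATOM's `score` signature.

```python
def accuracy(y_true, y_pred):
    """Fraction of correct predictions."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true, y_pred):
    """Fraction of wrong predictions."""
    return 1 - accuracy(y_true, y_pred)


METRICS = {"accuracy": accuracy, "error_rate": error_rate}


class ToyModel:
    def __init__(self, metric="accuracy"):
        self.metric = metric  # the metric the model is evaluated on

    def predict(self, X):
        return [1 if x > 0 else 0 for x in X]

    def score(self, X, y, metric=None):
        # Fall back to the model's own metric unless another one is requested.
        name = metric or self.metric
        return METRICS[name](y, self.predict(X))


model = ToyModel(metric="accuracy")
X, y = [-2, -1, 1, 3], [0, 1, 1, 1]
print(model.score(X, y))                       # 0.75
print(model.score(X, y, metric="error_rate"))  # 0.25
```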
Note
- The output of ATOM's methods is a pandas object, not a numpy array.
- The `predict_proba` method of some meta-estimators for multioutput tasks (such as `MultiOutputClassifier`) returns 3 dimensions, namely, a list of arrays with shape=(n_samples, n_classes), one array per target column. Since ATOM's prediction methods return pandas objects, such 3-dimensional arrays are converted to a multiindex pd.DataFrame, where the first level of the row index is the target columns and the second level is the classes.
- The prediction results are cached after the first call to avoid repeating expensive calculations. This mechanism can increase the size of the instance for large datasets. Use the `clear` method to free the memory.
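The caching trade-off from the last note (faster repeated calls at the cost of memory, undone by `clear`) can be sketched as follows. `CachedPredictor` is a hypothetical wrapper written for this example; it only mirrors the idea and makes no claim about how ATOM stores its cache.

```python
class CachedPredictor:
    """Hypothetical wrapper that caches predictions per row selection."""

    def __init__(self, predict_fn):
        self._predict_fn = predict_fn
        self._cache = {}
        self.calls = 0  # how often the expensive prediction actually ran

    def predict(self, rows):
        key = tuple(rows)  # lists are unhashable, so use a tuple as the key
        if key not in self._cache:
            self.calls += 1
            self._cache[key] = self._predict_fn(rows)
        return self._cache[key]

    def clear(self):
        """Free memory by dropping all cached predictions."""
        self._cache.clear()


model = CachedPredictor(lambda rows: [r * 2 for r in rows])
model.predict([1, 2, 3])
model.predict([1, 2, 3])  # served from the cache, no recomputation
print(model.calls)  # 1
model.clear()
model.predict([1, 2, 3])  # recomputed after clearing the cache
print(model.calls)  # 2
```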
It's also possible to get the prediction for a specific row or rows in
the dataset. See the row and column selection section in the user guide
to learn how to select the rows, e.g., `atom.rf.predict("test")`
or `atom.rf.predict_proba(range(100))`.
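To make the idea of predicting on existing rows concrete, the sketch below resolves a row selection (a data set name, a single row, or a range) to row positions before predicting. `select_rows`, `DATA_SETS` and the toy threshold model are hypothetical; ATOM's actual selection rules are described in the row and column selection section of the user guide and support more input types than shown here.

```python
from typing import List, Sequence, Union

# Hypothetical toy dataset: 9 rows, the last 3 form the test set.
DATA_SETS = {"train": list(range(6)), "test": [6, 7, 8]}
FEATURES = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5]


def select_rows(rows: Union[str, int, Sequence[int]]) -> List[int]:
    """Resolve a row selection to a list of row positions (sketch only)."""
    if isinstance(rows, str):
        return list(DATA_SETS[rows])  # name of a data set, e.g. "test"
    if isinstance(rows, int):
        return [rows]  # a single row
    return list(rows)  # any sequence of rows, e.g. range(3)


def predict(rows):
    """Toy model: predict 1 when the feature exceeds 4, else 0."""
    return {r: int(FEATURES[r] > 4) for r in select_rows(rows)}


print(predict("test"))    # {6: 1, 7: 1, 8: 1}
print(predict(range(3)))  # {0: 0, 1: 0, 2: 0}
```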
Note
For forecast models, prediction on rows follows the ForecastingHorizon
API. That means that selecting rows by their index works, but integers
are interpreted as forecasting horizon steps. For example,
`atom.arima.predict(1)` returns the prediction on the first row
of the test set (instead of the second row of the train set).
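The difference between horizon steps and plain positions can be shown with a small sketch. It assumes a relative horizon where step 1 is the first period after the end of the training data, and that the test set directly follows the training set; `rows_for_horizon` and the monthly labels are hypothetical, and absolute horizons are not covered.

```python
train_index = ["2024-01", "2024-02", "2024-03", "2024-04"]
test_index = ["2024-05", "2024-06", "2024-07"]


def rows_for_horizon(fh):
    """Map relative horizon steps to test-set rows.

    Step 1 is the first period after the end of the training data,
    step 2 the second one, and so on.
    """
    rows = []
    for step in fh:
        if not 1 <= step <= len(test_index):
            raise ValueError(f"step {step} is outside the test set")
        rows.append(test_index[step - 1])
    return rows


full_index = train_index + test_index
print(full_index[1])          # 2024-02 (plain positional indexing)
print(rows_for_horizon([1]))  # ['2024-05'] (horizon step 1)
```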