Predicting
After a successful pipeline run, you may want to apply the same
transformations to new data, or make predictions using one of the
trained models. Just like a sklearn estimator, you can call the
prediction methods from a fitted trainer, e.g. atom.predict(X).
Calling the method without specifying a model uses the winning model
in the pipeline (available under the winner attribute). To use a
different model, call the method from that model, e.g.
atom.AdaB.predict(X).
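The delegation from trainer to winning model can be pictured with a small, library-free sketch. ToyTrainer, ToyModel and their fields are hypothetical stand-ins, not ATOM classes, and the sketch assumes the winner is simply the model with the highest score.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyModel:
    """Hypothetical stand-in for a fitted model."""
    name: str
    score: float                    # metric score used to pick the winner
    rule: Callable[[float], int]    # toy decision rule for a single sample

    def predict(self, X: List[float]) -> List[int]:
        return [self.rule(x) for x in X]


class ToyTrainer:
    """Hypothetical trainer that exposes its models as attributes."""

    def __init__(self, models: List[ToyModel]):
        self._models: Dict[str, ToyModel] = {m.name.lower(): m for m in models}

    def __getattr__(self, name: str) -> ToyModel:
        # Only reached when normal attribute lookup fails.
        try:
            return self.__dict__["_models"][name.lower()]
        except KeyError:
            raise AttributeError(name) from None

    @property
    def winner(self) -> ToyModel:
        # Assumption: the winner is the model with the best score.
        return max(self._models.values(), key=lambda m: m.score)

    def predict(self, X: List[float]) -> List[int]:
        # No model specified: delegate to the winner.
        return self.winner.predict(X)


trainer = ToyTrainer([
    ToyModel("RF", 0.91, lambda x: int(x > 0.5)),
    ToyModel("AdaB", 0.87, lambda x: int(x > 0.3)),
])
winner_pred = trainer.predict([0.4, 0.6])       # uses the winner (RF)
adab_pred = trainer.AdaB.predict([0.4, 0.6])    # uses AdaB explicitly
```

The two calls return different predictions here because each toy model applies its own threshold.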
All prediction methods transform the provided data through all
transformers in the current branch before making the predictions.
By default, this excludes transformers that should only be applied
on the training set, like outlier pruning and balancing the dataset.
Use the method's pipeline parameter to customize which
transformations to apply with every call.
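The following sketch illustrates the idea of skipping train-only transformers by default and overriding that choice per call. It is not ATOM's implementation: Step, transform and the convention that pipeline takes a list of step positions are assumptions made for the example, so check the API reference for the values the real pipeline parameter accepts.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

Rows = List[float]


@dataclass
class Step:
    """Hypothetical transformer wrapper."""
    name: str
    func: Callable[[Rows], Rows]
    train_only: bool = False    # True for steps like pruning or balancing


def transform(steps: List[Step], X: Rows, pipeline: Optional[List[int]] = None) -> Rows:
    """Apply the selected steps in order.

    pipeline=None: skip train-only steps (the default behavior).
    pipeline=[i, j, ...]: apply exactly the steps at those positions.
    """
    if pipeline is None:
        selected = [s for s in steps if not s.train_only]
    else:
        selected = [steps[i] for i in pipeline]
    for step in selected:
        X = step.func(X)
    return X


steps = [
    Step("scale", lambda X: [x / 10 for x in X]),
    Step("prune_outliers", lambda X: [x for x in X if x < 5], train_only=True),
]
new_data = [10.0, 20.0, 90.0]
default_out = transform(steps, new_data)           # pruning is skipped
custom_out = transform(steps, new_data, [0, 1])    # pruning is forced
```

With the default, all three rows survive. Forcing the pruning step drops the outlier, which is usually undesirable when every new row needs a prediction.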
The available prediction methods are a selection of the most common methods for estimators in sklearn's API:

| Method | Description |
| --- | --- |
| transform | Transform new data through all transformers in a branch. |
| predict | Return class predictions. |
| predict_proba | Return class probabilities. |
| predict_log_proba | Return class log-probabilities. |
| decision_function | Return confidence scores. |
| score | Return a metric score. |
Except for transform, the prediction methods can also be calculated on
the train and test set. You can access these results through the model's
prediction attributes, e.g. atom.mnb.predict_train or
atom.mnb.predict_test.
Keep in mind that the results are not calculated until the attribute is
accessed for the first time. This mechanism avoids calculating
attributes that are never used, saving time and memory.
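One common way to get this compute-on-first-access behavior in Python is functools.cached_property. The sketch below shows that general technique with a hypothetical LazyModel class. Whether ATOM uses this exact mechanism internally is not stated here.

```python
from functools import cached_property


class LazyModel:
    """Hypothetical model whose test predictions are computed lazily."""

    def __init__(self, test_rows):
        self._test_rows = test_rows
        self.n_computations = 0    # counts how often predictions are computed

    @cached_property
    def predict_test(self):
        # Runs only on first access; the result is then stored on the instance.
        self.n_computations += 1
        return [int(x > 0.5) for x in self._test_rows]


model = LazyModel([0.2, 0.7, 0.9])
before = model.n_computations    # nothing computed yet
first = model.predict_test       # triggers the computation
second = model.predict_test      # served from the cache
```

After both accesses, n_computations is still 1 and both names refer to the same cached list.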
Except for transform and score, the prediction methods also accept the
index names or positions of rows already in the dataset instead of new
data. For example, atom.rf.predict(10) returns the random forest's
prediction on the row at position 10, and
atom.rf.predict_proba(["index1", "index2"]) returns the model's class
probabilities for the rows with index names index1 and index2.
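As a rough illustration of selecting rows by name or position, the helper below resolves a selector against a list of index names. select_rows is hypothetical, and treating integers as positions and strings as index names is an assumption of this sketch, not a description of how ATOM resolves ambiguous cases.

```python
from typing import List, Sequence, Union

Selector = Union[int, str, Sequence[Union[int, str]]]


def select_rows(index: List[str], rows: List[list], selector: Selector) -> List[list]:
    """Return the rows matching a position, an index name, or a list of either."""
    if isinstance(selector, (int, str)):
        selector = [selector]    # normalize a single selector to a list
    picked = []
    for item in selector:
        # Assumption: ints are positions, strings are index names.
        pos = item if isinstance(item, int) else index.index(item)
        picked.append(rows[pos])
    return picked


index = ["index0", "index1", "index2"]
rows = [[0.1], [0.5], [0.9]]
by_position = select_rows(index, rows, 1)
by_name = select_rows(index, rows, ["index1", "index2"])
```

The selected rows would then be passed to the model's prediction method in place of new data.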