{
"cells": [
{
"cell_type": "markdown",
"id": "9d021c80-8290-4533-9f09-71b0edba468d",
"metadata": {},
"source": [
"# Example: Accelerating pipelines on GPU\n",
"----------------------------------------\n",
"\n",
"This example shows how to accelerate a pipeline on the GPU using [cuML](https://docs.rapids.ai/api/cuml/stable/).\n",
"\n",
"The data used is a synthetic dataset created using scikit-learn's [make_classification](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html) function."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b57d1dff-114e-4f92-b2a5-536aaf5ce2ad",
"metadata": {},
"outputs": [],
"source": [
"from atom import ATOMClassifier\n",
"from sklearn.datasets import make_classification\n",
"\n",
"# Create a dummy dataset\n",
"X, y = make_classification(n_samples=100000, n_features=40)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "da7bcaba-27ed-4d19-92c8-826ed808fe5c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<< ================== ATOM ================== >>\n",
"Algorithm task: binary classification.\n",
"GPU training enabled.\n",
"Backend engine: cuml.\n",
"\n",
"Dataset stats ==================== >>\n",
"Shape: (100000, 41)\n",
"Memory: 32.80 MB\n",
"Scaled: True\n",
"Outlier values: 8127 (0.2%)\n",
"-------------------------------------\n",
"Train set size: 80000\n",
"Test set size: 20000\n",
"-------------------------------------\n",
"| | dataset | train | test |\n",
"| - | ------------- | ------------- | ------------- |\n",
"| 0 | 50006 (1.0) | 40005 (1.0) | 10001 (1.0) |\n",
"| 1 | 49994 (1.0) | 39995 (1.0) | 9999 (1.0) |\n",
"\n"
]
}
],
"source": [
"atom = ATOMClassifier(X, y, device=\"gpu\", engine=\"cuml\", verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4f986f25-d52f-45e9-b3e2-e645204363d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting Scaler...\n",
"Scaling features...\n"
]
}
],
"source": [
"atom.scale()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "72caeba1-d3a0-4ea1-8eea-2d2cdb4ba352",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" x0 | \n",
" x1 | \n",
" x2 | \n",
" x3 | \n",
" x4 | \n",
" x5 | \n",
" x6 | \n",
" x7 | \n",
" x8 | \n",
" x9 | \n",
" ... | \n",
" x31 | \n",
" x32 | \n",
" x33 | \n",
" x34 | \n",
" x35 | \n",
" x36 | \n",
" x37 | \n",
" x38 | \n",
" x39 | \n",
" target | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2.021646 | \n",
" -0.634557 | \n",
" -0.867811 | \n",
" 1.103642 | \n",
" 1.559011 | \n",
" 0.122284 | \n",
" -0.864821 | \n",
" 1.411657 | \n",
" 0.147997 | \n",
" -2.269082 | \n",
" ... | \n",
" -0.489864 | \n",
" 1.861048 | \n",
" -0.353861 | \n",
" 0.720823 | \n",
" -1.522117 | \n",
" -0.737707 | \n",
" -1.573936 | \n",
" -0.832174 | \n",
" 0.203154 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" -0.019885 | \n",
" 0.846568 | \n",
" -0.364059 | \n",
" -1.091604 | \n",
" -1.336692 | \n",
" 0.186689 | \n",
" -0.274142 | \n",
" 0.020563 | \n",
" 0.693235 | \n",
" -1.908658 | \n",
" ... | \n",
" -1.610058 | \n",
" -0.365231 | \n",
" 0.284908 | \n",
" 0.170156 | \n",
" -0.236553 | \n",
" -0.573761 | \n",
" -0.107317 | \n",
" -2.480178 | \n",
" 0.420341 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.516618 | \n",
" -0.013420 | \n",
" -0.753879 | \n",
" -0.488243 | \n",
" 0.560051 | \n",
" 0.395817 | \n",
" -0.522523 | \n",
" -1.083503 | \n",
" -0.073398 | \n",
" 0.383061 | \n",
" ... | \n",
" 0.966283 | \n",
" 1.405546 | \n",
" -0.658654 | \n",
" 0.339090 | \n",
" -1.615997 | \n",
" -1.312444 | \n",
" 0.984578 | \n",
" 0.602858 | \n",
" -1.110684 | \n",
" 1 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.111861 | \n",
" -0.966334 | \n",
" 0.208509 | \n",
" 0.494328 | \n",
" -0.766835 | \n",
" -0.003399 | \n",
" -0.500449 | \n",
" -0.530622 | \n",
" -0.481663 | \n",
" -1.146132 | \n",
" ... | \n",
" -0.304896 | \n",
" 2.030211 | \n",
" -1.189488 | \n",
" -1.238600 | \n",
" 1.658765 | \n",
" -0.255644 | \n",
" 0.572194 | \n",
" 0.195496 | \n",
" 0.617734 | \n",
" 1 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.160135 | \n",
" -0.873517 | \n",
" 0.719142 | \n",
" -2.020767 | \n",
" 0.421435 | \n",
" -1.941230 | \n",
" 0.835615 | \n",
" -1.178845 | \n",
" 0.235273 | \n",
" -0.328574 | \n",
" ... | \n",
" 1.633662 | \n",
" -0.631118 | \n",
" 1.814046 | \n",
" 1.031754 | \n",
" 0.328665 | \n",
" 1.704483 | \n",
" 2.153710 | \n",
" -1.430552 | \n",
" -0.543915 | \n",
" 1 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 99995 | \n",
" 1.100240 | \n",
" 0.092581 | \n",
" -0.346265 | \n",
" 0.234024 | \n",
" 0.590199 | \n",
" 0.755019 | \n",
" -1.688456 | \n",
" -1.031070 | \n",
" -0.620193 | \n",
" -0.283336 | \n",
" ... | \n",
" 0.356480 | \n",
" 1.346821 | \n",
" -0.299087 | \n",
" 2.343587 | \n",
" -2.003646 | \n",
" -0.933179 | \n",
" 0.764255 | \n",
" -0.233526 | \n",
" -1.462311 | \n",
" 1 | \n",
"
\n",
" \n",
" 99996 | \n",
" -1.142596 | \n",
" 0.321843 | \n",
" -0.974006 | \n",
" 0.390418 | \n",
" 0.404722 | \n",
" -0.324256 | \n",
" -0.288176 | \n",
" 1.009458 | \n",
" 0.860912 | \n",
" -0.191313 | \n",
" ... | \n",
" 0.044618 | \n",
" -2.030135 | \n",
" 1.448640 | \n",
" -0.854798 | \n",
" 1.441451 | \n",
" 1.347461 | \n",
" -0.937607 | \n",
" 0.572504 | \n",
" -0.787673 | \n",
" 0 | \n",
"
\n",
" \n",
" 99997 | \n",
" 1.658252 | \n",
" 0.303637 | \n",
" -0.020324 | \n",
" 0.225917 | \n",
" 0.154092 | \n",
" -1.208507 | \n",
" -0.199919 | \n",
" 1.063016 | \n",
" -0.395696 | \n",
" -0.060886 | \n",
" ... | \n",
" 1.563345 | \n",
" -1.261853 | \n",
" -0.810122 | \n",
" -0.503823 | \n",
" 1.565602 | \n",
" -1.264792 | \n",
" -0.591644 | \n",
" 1.588397 | \n",
" 0.601721 | \n",
" 0 | \n",
"
\n",
" \n",
" 99998 | \n",
" -0.288042 | \n",
" -1.139792 | \n",
" 1.548338 | \n",
" 0.501413 | \n",
" 0.361604 | \n",
" -0.315720 | \n",
" -0.564607 | \n",
" 1.500870 | \n",
" 0.501768 | \n",
" 0.649079 | \n",
" ... | \n",
" 0.344663 | \n",
" 1.734476 | \n",
" 0.660177 | \n",
" 0.767554 | \n",
" 1.461940 | \n",
" 0.310189 | \n",
" -1.469978 | \n",
" 0.900132 | \n",
" 1.114330 | \n",
" 0 | \n",
"
\n",
" \n",
" 99999 | \n",
" -3.093351 | \n",
" -0.636463 | \n",
" -0.449575 | \n",
" 1.169980 | \n",
" -1.041870 | \n",
" -0.257173 | \n",
" 2.072777 | \n",
" -0.101111 | \n",
" -0.956916 | \n",
" -0.251162 | \n",
" ... | \n",
" 2.250647 | \n",
" 0.746250 | \n",
" -0.610311 | \n",
" 0.445467 | \n",
" -0.636288 | \n",
" -0.187444 | \n",
" 0.226108 | \n",
" -0.186927 | \n",
" -1.024960 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
100000 rows × 41 columns
\n",
"
"
],
"text/plain": [
" x0 x1 x2 x3 x4 x5 x6 \\\n",
"0 2.021646 -0.634557 -0.867811 1.103642 1.559011 0.122284 -0.864821 \n",
"1 -0.019885 0.846568 -0.364059 -1.091604 -1.336692 0.186689 -0.274142 \n",
"2 0.516618 -0.013420 -0.753879 -0.488243 0.560051 0.395817 -0.522523 \n",
"3 0.111861 -0.966334 0.208509 0.494328 -0.766835 -0.003399 -0.500449 \n",
"4 0.160135 -0.873517 0.719142 -2.020767 0.421435 -1.941230 0.835615 \n",
"... ... ... ... ... ... ... ... \n",
"99995 1.100240 0.092581 -0.346265 0.234024 0.590199 0.755019 -1.688456 \n",
"99996 -1.142596 0.321843 -0.974006 0.390418 0.404722 -0.324256 -0.288176 \n",
"99997 1.658252 0.303637 -0.020324 0.225917 0.154092 -1.208507 -0.199919 \n",
"99998 -0.288042 -1.139792 1.548338 0.501413 0.361604 -0.315720 -0.564607 \n",
"99999 -3.093351 -0.636463 -0.449575 1.169980 -1.041870 -0.257173 2.072777 \n",
"\n",
" x7 x8 x9 ... x31 x32 x33 \\\n",
"0 1.411657 0.147997 -2.269082 ... -0.489864 1.861048 -0.353861 \n",
"1 0.020563 0.693235 -1.908658 ... -1.610058 -0.365231 0.284908 \n",
"2 -1.083503 -0.073398 0.383061 ... 0.966283 1.405546 -0.658654 \n",
"3 -0.530622 -0.481663 -1.146132 ... -0.304896 2.030211 -1.189488 \n",
"4 -1.178845 0.235273 -0.328574 ... 1.633662 -0.631118 1.814046 \n",
"... ... ... ... ... ... ... ... \n",
"99995 -1.031070 -0.620193 -0.283336 ... 0.356480 1.346821 -0.299087 \n",
"99996 1.009458 0.860912 -0.191313 ... 0.044618 -2.030135 1.448640 \n",
"99997 1.063016 -0.395696 -0.060886 ... 1.563345 -1.261853 -0.810122 \n",
"99998 1.500870 0.501768 0.649079 ... 0.344663 1.734476 0.660177 \n",
"99999 -0.101111 -0.956916 -0.251162 ... 2.250647 0.746250 -0.610311 \n",
"\n",
" x34 x35 x36 x37 x38 x39 target \n",
"0 0.720823 -1.522117 -0.737707 -1.573936 -0.832174 0.203154 0 \n",
"1 0.170156 -0.236553 -0.573761 -0.107317 -2.480178 0.420341 0 \n",
"2 0.339090 -1.615997 -1.312444 0.984578 0.602858 -1.110684 1 \n",
"3 -1.238600 1.658765 -0.255644 0.572194 0.195496 0.617734 1 \n",
"4 1.031754 0.328665 1.704483 2.153710 -1.430552 -0.543915 1 \n",
"... ... ... ... ... ... ... ... \n",
"99995 2.343587 -2.003646 -0.933179 0.764255 -0.233526 -1.462311 1 \n",
"99996 -0.854798 1.441451 1.347461 -0.937607 0.572504 -0.787673 0 \n",
"99997 -0.503823 1.565602 -1.264792 -0.591644 1.588397 0.601721 0 \n",
"99998 0.767554 1.461940 0.310189 -1.469978 0.900132 1.114330 0 \n",
"99999 0.445467 -0.636288 -0.187444 0.226108 -0.186927 -1.024960 1 \n",
"\n",
"[100000 rows x 41 columns]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atom.dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "96f5ee24-5f1a-4293-98d1-a1896c7ad1e6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Scaler used: StandardScaler()\n",
"Scaler's module: cuml._thirdparty.sklearn.preprocessing._data\n"
]
}
],
"source": [
"print(f\"Scaler used: {atom.standard}\")\n",
"print(f\"Scaler's module: {atom.standard.__class__.__module__}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a7c74dcf-a570-488a-bdcc-2d5bb234e887",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Training ========================= >>\n",
"Models: RF, SGD, XGB\n",
"Metric: f1\n",
"\n",
"\n",
"Results for RandomForest:\n",
"Fit ---------------------------------------------\n",
"Train evaluation --> f1: 0.9726\n",
"Test evaluation --> f1: 0.9431\n",
"Time elapsed: 1.935s\n",
"-------------------------------------------------\n",
"Total time: 1.935s\n",
"\n",
"\n",
"Results for StochasticGradientDescent:\n",
"Fit ---------------------------------------------\n",
"Train evaluation --> f1: 0.9236\n",
"Test evaluation --> f1: 0.9219\n",
"Time elapsed: 02m:16s\n",
"-------------------------------------------------\n",
"Total time: 02m:16s\n",
"\n",
"\n",
"Results for XGBoost:\n",
"Fit ---------------------------------------------\n",
"Train evaluation --> f1: 0.9749\n",
"Test evaluation --> f1: 0.9437\n",
"Time elapsed: 6.394s\n",
"-------------------------------------------------\n",
"Total time: 6.394s\n",
"\n",
"\n",
"Final results ==================== >>\n",
"Total time: 02m:24s\n",
"-------------------------------------\n",
"RandomForest --> f1: 0.9431\n",
"StochasticGradientDescent --> f1: 0.9219\n",
"XGBoost --> f1: 0.9437 !\n"
]
}
],
"source": [
"atom.run(models=[\"RF\", \"SGD\", \"XGB\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6a86a67e-9067-454f-bc92-dafb955e42bf",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" score_train | \n",
" score_test | \n",
" time_fit | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" RF | \n",
" 0.9726 | \n",
" 0.9431 | \n",
" 1.934512 | \n",
" 1.934512 | \n",
"
\n",
" \n",
" SGD | \n",
" 0.9236 | \n",
" 0.9219 | \n",
" 135.871493 | \n",
" 135.871493 | \n",
"
\n",
" \n",
" XGB | \n",
" 0.9749 | \n",
" 0.9437 | \n",
" 6.394416 | \n",
" 6.394416 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" score_train score_test time_fit time\n",
"RF 0.9726 0.9431 1.934512 1.934512\n",
"SGD 0.9236 0.9219 135.871493 135.871493\n",
"XGB 0.9749 0.9437 6.394416 6.394416"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atom.results"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "582646d5-49f9-451f-a85e-ea67c849e83d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RF's module: cuml.ensemble.randomforestclassifier\n",
"SGD's module: sklearn.linear_model._stochastic_gradient\n",
"XGB's module: xgboost.sklearn\n"
]
}
],
"source": [
"for m in atom.models:\n",
" print(f\"{m}'s module: {atom[m].estimator.__class__.__module__}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8f893e9e-9702-45f4-b845-a60f777d3015",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" accuracy | \n",
" average_precision | \n",
" balanced_accuracy | \n",
" f1 | \n",
" jaccard | \n",
" matthews_corrcoef | \n",
" precision | \n",
" recall | \n",
" roc_auc | \n",
"
\n",
" \n",
" \n",
" \n",
" RF | \n",
" 0.9429 | \n",
" 0.9741 | \n",
" 0.9429 | \n",
" 0.9431 | \n",
" 0.8924 | \n",
" 0.8858 | \n",
" 0.9391 | \n",
" 0.9472 | \n",
" 0.9792 | \n",
"
\n",
" \n",
" SGD | \n",
" 0.9217 | \n",
" 0.9635 | \n",
" 0.9218 | \n",
" 0.9219 | \n",
" 0.8551 | \n",
" 0.8435 | \n",
" 0.9203 | \n",
" 0.9235 | \n",
" 0.9676 | \n",
"
\n",
" \n",
" XGB | \n",
" 0.9434 | \n",
" 0.9753 | \n",
" 0.9434 | \n",
" 0.9437 | \n",
" 0.8933 | \n",
" 0.8868 | \n",
" 0.9385 | \n",
" 0.9489 | \n",
" 0.9798 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" accuracy average_precision balanced_accuracy f1 jaccard \\\n",
"RF 0.9429 0.9741 0.9429 0.9431 0.8924 \n",
"SGD 0.9217 0.9635 0.9218 0.9219 0.8551 \n",
"XGB 0.9434 0.9753 0.9434 0.9437 0.8933 \n",
"\n",
" matthews_corrcoef precision recall roc_auc \n",
"RF 0.8858 0.9391 0.9472 0.9792 \n",
"SGD 0.8435 0.9203 0.9235 0.9676 \n",
"XGB 0.8868 0.9385 0.9489 0.9798 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atom.evaluate()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "rapids-22.08:Python",
"language": "python",
"name": "conda-env-rapids-22.08-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}