{ "cells": [ { "cell_type": "markdown", "id": "9d021c80-8290-4533-9f09-71b0edba468d", "metadata": {}, "source": [ "# Example: Accelerating pipelines on GPU\n", "----------------------------------------\n", "\n", "This example shows how to accelerate a pipeline on GPU using cuML.\n", "\n", "The data used is a synthetic dataset created using sklearn's [make_classification](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html) function." ] }, { "cell_type": "code", "execution_count": 1, "id": "b57d1dff-114e-4f92-b2a5-536aaf5ce2ad", "metadata": {}, "outputs": [], "source": [ "from atom import ATOMClassifier\n", "from sklearn.datasets import make_classification\n", "\n", "# Create a dummy dataset\n", "X, y = make_classification(n_samples=100000, n_features=40)" ] }, { "cell_type": "code", "execution_count": 2, "id": "da7bcaba-27ed-4d19-92c8-826ed808fe5c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<< ================== ATOM ================== >>\n", "Algorithm task: binary classification.\n", "GPU training enabled.\n", "Backend engine: cuml.\n", "\n", "Dataset stats ==================== >>\n", "Shape: (100000, 41)\n", "Memory: 32.80 MB\n", "Scaled: True\n", "Outlier values: 8127 (0.2%)\n", "-------------------------------------\n", "Train set size: 80000\n", "Test set size: 20000\n", "-------------------------------------\n", "| | dataset | train | test |\n", "| - | ------------- | ------------- | ------------- |\n", "| 0 | 50006 (1.0) | 40005 (1.0) | 10001 (1.0) |\n", "| 1 | 49994 (1.0) | 39995 (1.0) | 9999 (1.0) |\n", "\n" ] } ], "source": [ "atom = ATOMClassifier(X, y, device=\"gpu\", engine=\"cuml\", verbose=2)" ] }, { "cell_type": "code", "execution_count": 3, "id": "4f986f25-d52f-45e9-b3e2-e645204363d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting Scaler...\n", "Scaling features...\n" ] } ], "source": [ "atom.scale()" ] }, { "cell_type": "code", "execution_count": 13, "id": "72caeba1-d3a0-4ea1-8eea-2d2cdb4ba352", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x0x1x2x3x4x5x6x7x8x9...x31x32x33x34x35x36x37x38x39target
02.021646-0.634557-0.8678111.1036421.5590110.122284-0.8648211.4116570.147997-2.269082...-0.4898641.861048-0.3538610.720823-1.522117-0.737707-1.573936-0.8321740.2031540
1-0.0198850.846568-0.364059-1.091604-1.3366920.186689-0.2741420.0205630.693235-1.908658...-1.610058-0.3652310.2849080.170156-0.236553-0.573761-0.107317-2.4801780.4203410
20.516618-0.013420-0.753879-0.4882430.5600510.395817-0.522523-1.083503-0.0733980.383061...0.9662831.405546-0.6586540.339090-1.615997-1.3124440.9845780.602858-1.1106841
30.111861-0.9663340.2085090.494328-0.766835-0.003399-0.500449-0.530622-0.481663-1.146132...-0.3048962.030211-1.189488-1.2386001.658765-0.2556440.5721940.1954960.6177341
40.160135-0.8735170.719142-2.0207670.421435-1.9412300.835615-1.1788450.235273-0.328574...1.633662-0.6311181.8140461.0317540.3286651.7044832.153710-1.430552-0.5439151
..................................................................
999951.1002400.092581-0.3462650.2340240.5901990.755019-1.688456-1.031070-0.620193-0.283336...0.3564801.346821-0.2990872.343587-2.003646-0.9331790.764255-0.233526-1.4623111
99996-1.1425960.321843-0.9740060.3904180.404722-0.324256-0.2881761.0094580.860912-0.191313...0.044618-2.0301351.448640-0.8547981.4414511.347461-0.9376070.572504-0.7876730
999971.6582520.303637-0.0203240.2259170.154092-1.208507-0.1999191.063016-0.395696-0.060886...1.563345-1.261853-0.810122-0.5038231.565602-1.264792-0.5916441.5883970.6017210
99998-0.288042-1.1397921.5483380.5014130.361604-0.315720-0.5646071.5008700.5017680.649079...0.3446631.7344760.6601770.7675541.4619400.310189-1.4699780.9001321.1143300
99999-3.093351-0.636463-0.4495751.169980-1.041870-0.2571732.072777-0.101111-0.956916-0.251162...2.2506470.746250-0.6103110.445467-0.636288-0.1874440.226108-0.186927-1.0249601
\n", "

100000 rows × 41 columns

\n", "
" ], "text/plain": [ " x0 x1 x2 x3 x4 x5 x6 \\\n", "0 2.021646 -0.634557 -0.867811 1.103642 1.559011 0.122284 -0.864821 \n", "1 -0.019885 0.846568 -0.364059 -1.091604 -1.336692 0.186689 -0.274142 \n", "2 0.516618 -0.013420 -0.753879 -0.488243 0.560051 0.395817 -0.522523 \n", "3 0.111861 -0.966334 0.208509 0.494328 -0.766835 -0.003399 -0.500449 \n", "4 0.160135 -0.873517 0.719142 -2.020767 0.421435 -1.941230 0.835615 \n", "... ... ... ... ... ... ... ... \n", "99995 1.100240 0.092581 -0.346265 0.234024 0.590199 0.755019 -1.688456 \n", "99996 -1.142596 0.321843 -0.974006 0.390418 0.404722 -0.324256 -0.288176 \n", "99997 1.658252 0.303637 -0.020324 0.225917 0.154092 -1.208507 -0.199919 \n", "99998 -0.288042 -1.139792 1.548338 0.501413 0.361604 -0.315720 -0.564607 \n", "99999 -3.093351 -0.636463 -0.449575 1.169980 -1.041870 -0.257173 2.072777 \n", "\n", " x7 x8 x9 ... x31 x32 x33 \\\n", "0 1.411657 0.147997 -2.269082 ... -0.489864 1.861048 -0.353861 \n", "1 0.020563 0.693235 -1.908658 ... -1.610058 -0.365231 0.284908 \n", "2 -1.083503 -0.073398 0.383061 ... 0.966283 1.405546 -0.658654 \n", "3 -0.530622 -0.481663 -1.146132 ... -0.304896 2.030211 -1.189488 \n", "4 -1.178845 0.235273 -0.328574 ... 1.633662 -0.631118 1.814046 \n", "... ... ... ... ... ... ... ... \n", "99995 -1.031070 -0.620193 -0.283336 ... 0.356480 1.346821 -0.299087 \n", "99996 1.009458 0.860912 -0.191313 ... 0.044618 -2.030135 1.448640 \n", "99997 1.063016 -0.395696 -0.060886 ... 1.563345 -1.261853 -0.810122 \n", "99998 1.500870 0.501768 0.649079 ... 0.344663 1.734476 0.660177 \n", "99999 -0.101111 -0.956916 -0.251162 ... 2.250647 0.746250 -0.610311 \n", "\n", " x34 x35 x36 x37 x38 x39 target \n", "0 0.720823 -1.522117 -0.737707 -1.573936 -0.832174 0.203154 0 \n", "1 0.170156 -0.236553 -0.573761 -0.107317 -2.480178 0.420341 0 \n", "2 0.339090 -1.615997 -1.312444 0.984578 0.602858 -1.110684 1 \n", "3 -1.238600 1.658765 -0.255644 0.572194 0.195496 0.617734 1 \n", "4 1.031754 0.328665 1.704483 2.153710 -1.430552 -0.543915 1 \n", "... ... ... ... ... ... ... ... 
\n", "99995 2.343587 -2.003646 -0.933179 0.764255 -0.233526 -1.462311 1 \n", "99996 -0.854798 1.441451 1.347461 -0.937607 0.572504 -0.787673 0 \n", "99997 -0.503823 1.565602 -1.264792 -0.591644 1.588397 0.601721 0 \n", "99998 0.767554 1.461940 0.310189 -1.469978 0.900132 1.114330 0 \n", "99999 0.445467 -0.636288 -0.187444 0.226108 -0.186927 -1.024960 1 \n", "\n", "[100000 rows x 41 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "96f5ee24-5f1a-4293-98d1-a1896c7ad1e6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Scaler used: StandardScaler()\n", "Scaler's module: cuml._thirdparty.sklearn.preprocessing._data\n" ] } ], "source": [ "print(f\"Scaler used: {atom.standard}\")\n", "print(f\"Scaler's module: {atom.standard.__class__.__module__}\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "a7c74dcf-a570-488a-bdcc-2d5bb234e887", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Training ========================= >>\n", "Models: RF, SGD, XGB\n", "Metric: f1\n", "\n", "\n", "Results for RandomForest:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.9726\n", "Test evaluation --> f1: 0.9431\n", "Time elapsed: 1.935s\n", "-------------------------------------------------\n", "Total time: 1.935s\n", "\n", "\n", "Results for StochasticGradientDescent:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.9236\n", "Test evaluation --> f1: 0.9219\n", "Time elapsed: 02m:16s\n", "-------------------------------------------------\n", "Total time: 02m:16s\n", "\n", "\n", "Results for XGBoost:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.9749\n", "Test evaluation --> f1: 0.9437\n", "Time elapsed: 6.394s\n", "-------------------------------------------------\n", "Total time: 6.394s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 02m:24s\n", "-------------------------------------\n", "RandomForest --> f1: 0.9431\n", "StochasticGradientDescent --> f1: 0.9219\n", "XGBoost --> f1: 0.9437 !\n" ] } ], "source": [ "atom.run(models=[\"RF\", \"SGD\", \"XGB\"])" ] }, { "cell_type": "code", "execution_count": 6, "id": "6a86a67e-9067-454f-bc92-dafb955e42bf", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
score_trainscore_testtime_fittime
RF0.97260.94311.9345121.934512
SGD0.92360.9219135.871493135.871493
XGB0.97490.94376.3944166.394416
\n", "
" ], "text/plain": [ " score_train score_test time_fit time\n", "RF 0.9726 0.9431 1.934512 1.934512\n", "SGD 0.9236 0.9219 135.871493 135.871493\n", "XGB 0.9749 0.9437 6.394416 6.394416" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.results" ] }, { "cell_type": "code", "execution_count": 7, "id": "582646d5-49f9-451f-a85e-ea67c849e83d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RF's module: cuml.ensemble.randomforestclassifier\n", "SGD's module: sklearn.linear_model._stochastic_gradient\n", "XGB's module: xgboost.sklearn\n" ] } ], "source": [ "for m in atom.models:\n", " print(f\"{m}'s module: {atom[m].estimator.__class__.__module__}\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "8f893e9e-9702-45f4-b845-a60f777d3015", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
accuracyaverage_precisionbalanced_accuracyf1jaccardmatthews_corrcoefprecisionrecallroc_auc
RF0.94290.97410.94290.94310.89240.88580.93910.94720.9792
SGD0.92170.96350.92180.92190.85510.84350.92030.92350.9676
XGB0.94340.97530.94340.94370.89330.88680.93850.94890.9798
\n", "
" ], "text/plain": [ " accuracy average_precision balanced_accuracy f1 jaccard \\\n", "RF 0.9429 0.9741 0.9429 0.9431 0.8924 \n", "SGD 0.9217 0.9635 0.9218 0.9219 0.8551 \n", "XGB 0.9434 0.9753 0.9434 0.9437 0.8933 \n", "\n", " matthews_corrcoef precision recall roc_auc \n", "RF 0.8858 0.9391 0.9472 0.9792 \n", "SGD 0.8435 0.9203 0.9235 0.9676 \n", "XGB 0.8868 0.9385 0.9489 0.9798 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.evaluate()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 5 }