{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Data engines\n", "-----------------------\n", "\n", "This example shows how ATOM interacts with data engines other than pandas, for example [polars](https://pola.rs/).\n", "\n", "Import the breast cancer dataset from [sklearn.datasets](https://scikit-learn.org/stable/datasets/toy_dataset.html#breast-cancer-dataset). This is a small, easy-to-train dataset whose goal is to predict whether a breast tumor is malignant (i.e., the patient has breast cancer) or benign." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import packages\n", "import polars as pl\n", "from sklearn.datasets import load_breast_cancer\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "shape: (5, 30)
mean radiusmean texturemean perimetermean areamean smoothnessmean compactnessmean concavitymean concave pointsmean symmetrymean fractal dimensionradius errortexture errorperimeter errorarea errorsmoothness errorcompactness errorconcavity errorconcave points errorsymmetry errorfractal dimension errorworst radiusworst textureworst perimeterworst areaworst smoothnessworst compactnessworst concavityworst concave pointsworst symmetryworst fractal dimension
f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64
17.9910.38122.81001.00.11840.27760.30010.14710.24190.078711.0950.90538.589153.40.0063990.049040.053730.015870.030030.00619325.3817.33184.62019.00.16220.66560.71190.26540.46010.1189
20.5717.77132.91326.00.084740.078640.08690.070170.18120.056670.54350.73393.39874.080.0052250.013080.01860.01340.013890.00353224.9923.41158.81956.00.12380.18660.24160.1860.2750.08902
19.6921.25130.01203.00.10960.15990.19740.12790.20690.059990.74560.78694.58594.030.006150.040060.038320.020580.02250.00457123.5725.53152.51709.00.14440.42450.45040.2430.36130.08758
11.4220.3877.58386.10.14250.28390.24140.10520.25970.097440.49561.1563.44527.230.009110.074580.056610.018670.059630.00920814.9126.598.87567.70.20980.86630.68690.25750.66380.173
20.2914.34135.11297.00.10030.13280.1980.10430.18090.058830.75720.78135.43894.440.011490.024610.056880.018850.017560.00511522.5416.67152.21575.00.13740.2050.40.16250.23640.07678
" ], "text/plain": [ "shape: (5, 30)\n", "┌─────────────┬──────────────┬────────────────┬───────────┬───┬─────────────────┬──────────────────────┬────────────────┬─────────────────────────┐\n", "│ mean radius ┆ mean texture ┆ mean perimeter ┆ mean area ┆ … ┆ worst concavity ┆ worst concave points ┆ worst symmetry ┆ worst fractal dimension │\n", "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════════════╪══════════════╪════════════════╪═══════════╪═══╪═════════════════╪══════════════════════╪════════════════╪═════════════════════════╡\n", "│ 17.99 ┆ 10.38 ┆ 122.8 ┆ 1001.0 ┆ … ┆ 0.7119 ┆ 0.2654 ┆ 0.4601 ┆ 0.1189 │\n", "│ 20.57 ┆ 17.77 ┆ 132.9 ┆ 1326.0 ┆ … ┆ 0.2416 ┆ 0.186 ┆ 0.275 ┆ 0.08902 │\n", "│ 19.69 ┆ 21.25 ┆ 130.0 ┆ 1203.0 ┆ … ┆ 0.4504 ┆ 0.243 ┆ 0.3613 ┆ 0.08758 │\n", "│ 11.42 ┆ 20.38 ┆ 77.58 ┆ 386.1 ┆ … ┆ 0.6869 ┆ 0.2575 ┆ 0.6638 ┆ 0.173 │\n", "│ 20.29 ┆ 14.34 ┆ 135.1 ┆ 1297.0 ┆ … ┆ 0.4 ┆ 0.1625 ┆ 0.2364 ┆ 0.07678 │\n", "└─────────────┴──────────────┴────────────────┴───────────┴───┴─────────────────┴──────────────────────┴────────────────┴─────────────────────────┘" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the data and convert to polars for demonstration purposes\n", "X, y = load_breast_cancer(return_X_y=True, as_frame=True)\n", "\n", "X = pl.from_pandas(X)\n", "y = pl.from_pandas(y)\n", "\n", "X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<< ================== ATOM ================== >>\n", "\n", "Configuration ==================== >>\n", "Algorithm task: Binary classification.\n", "Data engine: polars\n", "\n", "Dataset stats ==================== >>\n", "Shape: (569, 31)\n", "Train set size: 456\n", "Test set size: 113\n", 
"-------------------------------------\n", "Memory: 138.97 kB\n", "Scaled: False\n", "Outlier values: 167 (1.2%)\n", "\n" ] } ], "source": [ "# Specify the data engine in the constructor\n", "# Note that atom accepts any dataframe-like object to create the dataset\n", "atom = ATOMClassifier(X, y, engine=\"polars\", verbose=2, random_state=1)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "shape: (5, 30)
mean radiusmean texturemean perimetermean areamean smoothnessmean compactnessmean concavitymean concave pointsmean symmetrymean fractal dimensionradius errortexture errorperimeter errorarea errorsmoothness errorcompactness errorconcavity errorconcave points errorsymmetry errorfractal dimension errorworst radiusworst textureworst perimeterworst areaworst smoothnessworst compactnessworst concavityworst concave pointsworst symmetryworst fractal dimension
f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64f64
13.4820.8288.4559.20.10160.12550.10630.054390.1720.064190.2130.59141.54518.520.0053670.022390.030490.012620.013770.00318715.5326.02107.3740.40.1610.42250.5030.22580.28070.1071
18.3120.58120.81052.00.10680.12480.15690.094510.1860.059410.54490.92253.21867.360.0061760.018770.029130.010460.015590.00272521.8626.2142.21493.00.14920.25360.37590.1510.30740.07863
17.9324.48115.2998.90.088550.070270.056990.047440.15380.05510.42121.4332.76545.810.0054440.011690.016220.0085220.014190.00275120.9234.69135.11320.00.13150.18060.2080.11360.25040.07948
15.1329.8196.71719.50.08320.046050.046860.027390.18520.052940.46811.6273.04345.380.0068310.014270.024890.0090870.031510.0017517.2636.91110.1931.40.11480.098660.15470.065750.32330.06165
8.9515.7658.74245.20.094620.12430.092630.023080.13050.071630.31320.97893.2816.940.018350.06760.092630.023080.023840.0056019.41417.0763.34270.00.11790.18790.15440.038460.16520.07722
" ], "text/plain": [ "shape: (5, 30)\n", "┌─────────────┬──────────────┬────────────────┬───────────┬───┬─────────────────┬──────────────────────┬────────────────┬─────────────────────────┐\n", "│ mean radius ┆ mean texture ┆ mean perimeter ┆ mean area ┆ … ┆ worst concavity ┆ worst concave points ┆ worst symmetry ┆ worst fractal dimension │\n", "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════════════╪══════════════╪════════════════╪═══════════╪═══╪═════════════════╪══════════════════════╪════════════════╪═════════════════════════╡\n", "│ 13.48 ┆ 20.82 ┆ 88.4 ┆ 559.2 ┆ … ┆ 0.503 ┆ 0.2258 ┆ 0.2807 ┆ 0.1071 │\n", "│ 18.31 ┆ 20.58 ┆ 120.8 ┆ 1052.0 ┆ … ┆ 0.3759 ┆ 0.151 ┆ 0.3074 ┆ 0.07863 │\n", "│ 17.93 ┆ 24.48 ┆ 115.2 ┆ 998.9 ┆ … ┆ 0.208 ┆ 0.1136 ┆ 0.2504 ┆ 0.07948 │\n", "│ 15.13 ┆ 29.81 ┆ 96.71 ┆ 719.5 ┆ … ┆ 0.1547 ┆ 0.06575 ┆ 0.3233 ┆ 0.06165 │\n", "│ 8.95 ┆ 15.76 ┆ 58.74 ┆ 245.2 ┆ … ┆ 0.1544 ┆ 0.03846 ┆ 0.1652 ┆ 0.07722 │\n", "└─────────────┴──────────────┴────────────────┴───────────┴───┴─────────────────┴──────────────────────┴────────────────┴─────────────────────────┘" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The data attributes return now polars types\n", "atom.X.head(5)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "shape: (5,)
target
i32
0
0
0
0
1
" ], "text/plain": [ "shape: (5,)\n", "Series: 'target' [i32]\n", "[\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t1\n", "]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.y.head(5)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Training ========================= >>\n", "Models: LR\n", "Metric: f1\n", "\n", "\n", "Results for LogisticRegression:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.9913\n", "Test evaluation --> f1: 0.9861\n", "Time elapsed: 0.129s\n", "-------------------------------------------------\n", "Time: 0.129s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 0.132s\n", "-------------------------------------\n", "LogisticRegression --> f1: 0.9861\n" ] } ], "source": [ "atom.run(\"LR\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyze the results" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "shape: (569,)
target
i64
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
0
0
0
0
0
0
1
" ], "text/plain": [ "shape: (569,)\n", "Series: 'target' [i64]\n", "[\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t…\n", "\t1\n", "\t1\n", "\t1\n", "\t1\n", "\t1\n", "\t1\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t0\n", "\t1\n", "]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The prediction methods also return types of the requested data engine\n", "atom.lr.predict(X)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 0\n", "1 0\n", "2 0\n", "3 0\n", "4 0\n", "Name: target, dtype: int64[pyarrow]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.lr.engine = \"pandas-pyarrow\"\n", "atom.lr.predict(X.head(5))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dask Series Structure:\n", "npartitions=1\n", "0 int64\n", "4 ...\n", "Name: target, dtype: int64\n", "Dask Name: from_pandas, 1 graph layer" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.lr.engine = \"dask\"\n", "atom.lr.predict(X.head(5))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", "[\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0\n", "]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "atom.lr.engine = \"pyarrow\"\n", "atom.lr.predict(X.head(5))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": 
true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }