{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Binary classification\n", "--------------------------------\n", "\n", "This example shows how to use ATOM to solve a binary classification problem. Additionally, we'll perform a variety of data cleaning steps to prepare the data for modeling.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal is to predict whether or not it will rain tomorrow by training a binary classifier on the target column `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>Location</th><th>MinTemp</th><th>MaxTemp</th><th>Rainfall</th><th>Evaporation</th><th>Sunshine</th><th>WindGustDir</th><th>WindGustSpeed</th><th>WindDir9am</th><th>WindDir3pm</th><th>...</th><th>Humidity9am</th><th>Humidity3pm</th><th>Pressure9am</th><th>Pressure3pm</th><th>Cloud9am</th><th>Cloud3pm</th><th>Temp9am</th><th>Temp3pm</th><th>RainToday</th><th>RainTomorrow</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>MelbourneAirport</td><td>18.0</td><td>26.9</td><td>21.4</td><td>7.0</td><td>8.9</td><td>SSE</td><td>41.0</td><td>W</td><td>SSE</td><td>...</td><td>95.0</td><td>54.0</td><td>1019.5</td><td>1017.0</td><td>8.0</td><td>5.0</td><td>18.5</td><td>26.0</td><td>Yes</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>Adelaide</td><td>17.2</td><td>23.4</td><td>0.0</td><td>NaN</td><td>NaN</td><td>S</td><td>41.0</td><td>S</td><td>WSW</td><td>...</td><td>59.0</td><td>36.0</td><td>1015.7</td><td>1015.7</td><td>NaN</td><td>NaN</td><td>17.7</td><td>21.9</td><td>No</td><td>0</td></tr>\n",
"    <tr><th>2</th><td>Cairns</td><td>18.6</td><td>24.6</td><td>7.4</td><td>3.0</td><td>6.1</td><td>SSE</td><td>54.0</td><td>SSE</td><td>SE</td><td>...</td><td>78.0</td><td>57.0</td><td>1018.7</td><td>1016.6</td><td>3.0</td><td>3.0</td><td>20.8</td><td>24.1</td><td>Yes</td><td>0</td></tr>\n",
"    <tr><th>3</th><td>Portland</td><td>13.6</td><td>16.8</td><td>4.2</td><td>1.2</td><td>0.0</td><td>ESE</td><td>39.0</td><td>ESE</td><td>ESE</td><td>...</td><td>76.0</td><td>74.0</td><td>1021.4</td><td>1020.5</td><td>7.0</td><td>8.0</td><td>15.6</td><td>16.0</td><td>Yes</td><td>1</td></tr>\n",
"    <tr><th>4</th><td>Walpole</td><td>16.4</td><td>19.9</td><td>0.0</td><td>NaN</td><td>NaN</td><td>SE</td><td>44.0</td><td>SE</td><td>SE</td><td>...</td><td>78.0</td><td>70.0</td><td>1019.4</td><td>1018.9</td><td>NaN</td><td>NaN</td><td>17.4</td><td>18.1</td><td>No</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>5 rows × 22 columns</p>\n",
"</div>
" ], "text/plain": [ " Location MinTemp MaxTemp Rainfall Evaporation Sunshine \\\n", "0 MelbourneAirport 18.0 26.9 21.4 7.0 8.9 \n", "1 Adelaide 17.2 23.4 0.0 NaN NaN \n", "2 Cairns 18.6 24.6 7.4 3.0 6.1 \n", "3 Portland 13.6 16.8 4.2 1.2 0.0 \n", "4 Walpole 16.4 19.9 0.0 NaN NaN \n", "\n", " WindGustDir WindGustSpeed WindDir9am WindDir3pm ... Humidity9am \\\n", "0 SSE 41.0 W SSE ... 95.0 \n", "1 S 41.0 S WSW ... 59.0 \n", "2 SSE 54.0 SSE SE ... 78.0 \n", "3 ESE 39.0 ESE ESE ... 76.0 \n", "4 SE 44.0 SE SE ... 78.0 \n", "\n", " Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am \\\n", "0 54.0 1019.5 1017.0 8.0 5.0 18.5 \n", "1 36.0 1015.7 1015.7 NaN NaN 17.7 \n", "2 57.0 1018.7 1016.6 3.0 3.0 20.8 \n", "3 74.0 1021.4 1020.5 7.0 8.0 15.6 \n", "4 70.0 1019.4 1018.9 NaN NaN 17.4 \n", "\n", " Temp3pm RainToday RainTomorrow \n", "0 26.0 Yes 0 \n", "1 21.9 No 0 \n", "2 24.1 Yes 0 \n", "3 16.0 Yes 1 \n", "4 18.1 No 0 \n", "\n", "[5 rows x 22 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load data\n", "X = pd.read_csv(\"./datasets/weatherAUS.csv\")\n", "\n", "# Let's have a look\n", "X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<< ================== ATOM ================== >>\n", "\n", "Configuration ==================== >>\n", "Algorithm task: Binary classification.\n", "Parallel processing with 8 cores.\n", "\n", "Dataset stats ==================== >>\n", "Shape: (7109, 22)\n", "Train set size: 5688\n", "Test set size: 1421\n", "-------------------------------------\n", "Memory: 1.25 MB\n", "Scaled: False\n", "Missing values: 15772 (10.1%)\n", "Categorical features: 5 (23.8%)\n", "\n" ] } ], "source": [ "# Call atom using only 5% of the complete dataset (for explanatory purposes)\n", "atom = ATOMClassifier(X, 
y=\"RainTomorrow\", n_rows=0.05, n_jobs=8, verbose=2)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 3 samples for containing more than 16 missing values.\n", " --> Dropping 896 samples for containing missing values in categorical columns.\n", " --> Imputing 10 missing values with median (12.2) in column MinTemp.\n", " --> Imputing 7 missing values with median (22.9) in column MaxTemp.\n", " --> Imputing 2560 missing values with median (4.8) in column Evaporation.\n", " --> Imputing 2841 missing values with median (8.4) in column Sunshine.\n", " --> Imputing 45 missing values with median (70.0) in column Humidity9am.\n", " --> Imputing 76 missing values with median (52.0) in column Humidity3pm.\n", " --> Imputing 470 missing values with median (1017.4) in column Pressure9am.\n", " --> Imputing 468 missing values with median (1015.0) in column Pressure3pm.\n", " --> Imputing 2346 missing values with median (5.0) in column Cloud9am.\n", " --> Imputing 2459 missing values with median (5.0) in column Cloud3pm.\n", " --> Imputing 16 missing values with median (16.95) in column Temp9am.\n", " --> Imputing 41 missing values with median (21.3) in column Temp3pm.\n" ] } ], "source": [ "# Impute missing values\n", "atom.impute(strat_num=\"median\", strat_cat=\"drop\", max_nan_rows=0.8)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 1 classes.\n", " --> Handling 6210 unknown classes.\n", " --> Target-encoding feature WindGustDir. Contains 16 classes.\n", " --> Target-encoding feature WindDir9am. Contains 16 classes.\n", " --> Target-encoding feature WindDir3pm. 
Contains 16 classes.\n", " --> Ordinal-encoding feature RainToday. Contains 2 classes.\n" ] } ], "source": [ "# Encode the categorical features\n", "atom.encode(strategy=\"Target\", max_onehot=10, infrequent_to_value=0.04)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Training ========================= >>\n", "Models: ET, RF\n", "Metric: f1\n", "\n", "\n", "Results for ExtraTrees:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 1.0\n", "Test evaluation --> f1: 0.5349\n", "Time elapsed: 0.273s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.5655 ± 0.0068\n", "Time elapsed: 0.902s\n", "-------------------------------------------------\n", "Time: 1.175s\n", "\n", "\n", "Results for RandomForest:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 1.0\n", "Test evaluation --> f1: 0.5714\n", "Time elapsed: 0.261s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.5718 ± 0.0131\n", "Time elapsed: 1.012s\n", "-------------------------------------------------\n", "Time: 1.273s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 2.458s\n", "-------------------------------------\n", "ExtraTrees --> f1: 0.5655 ± 0.0068 ~\n", "RandomForest --> f1: 0.5718 ± 0.0131 ~ !\n" ] } ], "source": [ "# Train an Extra-Trees and a Random Forest model\n", "atom.run(models=[\"ET\", \"RF\"], metric=\"f1\", n_bootstrap=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyze the results" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>f1_train</th><th>f1_test</th><th>time_fit</th><th>f1_bootstrap</th><th>time_bootstrap</th><th>time</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>ET</th><td>0.8566</td><td>0.5615</td><td>0.273168</td><td>0.565538</td><td>0.902319</td><td>1.175487</td></tr>\n",
"    <tr><th>RF</th><td>1.0000</td><td>0.5714</td><td>0.261275</td><td>0.571824</td><td>1.011731</td><td>1.273006</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " f1_train f1_test time_fit f1_bootstrap time_bootstrap time\n", "ET 0.8566 0.5615 0.273168 0.565538 0.902319 1.175487\n", "RF 1.0000 0.5714 0.261275 0.571824 1.011731 1.273006" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's have a look at the final results\n", "atom.results" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plotly.com" }, "data": [ { "hovertemplate": "%{x}", "legendgroup": "f1", "marker": { "color": "rgba(0, 98, 98, 0.2)", "line": { "color": "rgb(0, 98, 98)", "width": 2 } }, "name": "f1", "orientation": "h", "showlegend": true, "type": "bar", "x": [ 0.5615, 0.5714 ], "xaxis": "x", "y": [ "ET", "RF" ], "yaxis": "y" } ], "layout": { "bargroupgap": 0.05, "font": { "size": 12 }, "height": 500, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "toggleitem", "traceorder": "grouped", "x": 0.99, "xanchor": "right", "y": 0.01, "yanchor": "bottom" }, "margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 49 }, "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { 
"outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { 
"colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, 
"#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { 
"gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "text": "RF vs ET performance", "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ 0, 0.6014736842105264 ], "title": { "font": { "size": 16 }, "text": "score" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, "autorange": true, "categoryorder": "total ascending", "domain": [ 0, 1 ], "range": [ -0.5, 1.5 ], "title": { "font": { "size": 16 } }, "type": "category" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize the bootstrap results\n", "atom.plot_results(title=\"RF vs ET performance\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
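<table>\n",
"  <thead>\n",
"    <tr><th></th><th>accuracy</th><th>ap</th><th>ba</th><th>f1</th><th>jaccard</th><th>mcc</th><th>precision</th><th>recall</th><th>auc</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>ET</th><td>0.846600</td><td>0.673300</td><td>0.703600</td><td>0.561500</td><td>0.390300</td><td>0.497800</td><td>0.746900</td><td>0.449800</td><td>0.851100</td></tr>\n",
"    <tr><th>RF</th><td>0.845000</td><td>0.674700</td><td>0.714600</td><td>0.576500</td><td>0.405000</td><td>0.499800</td><td>0.714300</td><td>0.483300</td><td>0.847100</td></tr>\n",
"  </tbody>\n",
"</table>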
\n" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Print the results of some common metrics\n", "atom.evaluate()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The winner is the RF model!!\n" ] } ], "source": [ "# The winner attribute calls the best model (atom.winner == atom.rf)\n", "print(f\"The winner is the {atom.winner.name} model!!\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plotly.com" }, "data": [ { "fill": "tonexty", "fillcolor": "rgba(0, 98, 98, 0.2)", "fillpattern": { "shape": "" }, "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "solid", "width": 2 }, "mode": "lines", "name": "RainTomorrow=0", "showlegend": true, "type": "scatter", "x": [ 0, 0.010101010101010102, 0.020202020202020204, 0.030303030303030304, 0.04040404040404041, 0.05050505050505051, 0.06060606060606061, 0.07070707070707072, 0.08080808080808081, 0.09090909090909091, 0.10101010101010102, 0.11111111111111112, 0.12121212121212122, 0.13131313131313133, 0.14141414141414144, 0.15151515151515152, 0.16161616161616163, 0.17171717171717174, 0.18181818181818182, 0.19191919191919193, 0.20202020202020204, 0.21212121212121213, 0.22222222222222224, 0.23232323232323235, 0.24242424242424243, 0.25252525252525254, 0.26262626262626265, 0.27272727272727276, 0.2828282828282829, 0.29292929292929293, 0.30303030303030304, 0.31313131313131315, 0.32323232323232326, 0.33333333333333337, 0.3434343434343435, 0.3535353535353536, 0.36363636363636365, 0.37373737373737376, 0.38383838383838387, 0.393939393939394, 0.4040404040404041, 0.4141414141414142, 0.42424242424242425, 0.43434343434343436, 0.4444444444444445, 0.4545454545454546, 0.4646464646464647, 
0.4747474747474748, 0.48484848484848486, 0.494949494949495, 0.5050505050505051, 0.5151515151515152, 0.5252525252525253, 0.5353535353535354, 0.5454545454545455, 0.5555555555555556, 0.5656565656565657, 0.5757575757575758, 0.5858585858585859, 0.595959595959596, 0.6060606060606061, 0.6161616161616162, 0.6262626262626263, 0.6363636363636365, 0.6464646464646465, 0.6565656565656566, 0.6666666666666667, 0.6767676767676768, 0.686868686868687, 0.696969696969697, 0.7070707070707072, 0.7171717171717172, 0.7272727272727273, 0.7373737373737375, 0.7474747474747475, 0.7575757575757577, 0.7676767676767677, 0.7777777777777778, 0.787878787878788, 0.797979797979798, 0.8080808080808082, 0.8181818181818182, 0.8282828282828284, 0.8383838383838385, 0.8484848484848485, 0.8585858585858587, 0.8686868686868687, 0.8787878787878789, 0.888888888888889, 0.8989898989898991, 0.9090909090909092, 0.9191919191919192, 0.9292929292929294, 0.9393939393939394, 0.9494949494949496, 0.9595959595959597, 0.9696969696969697, 0.9797979797979799, 0.98989898989899, 1 ], "xaxis": "x", "y": [ 3.086090078263203, 3.5456060539414302, 3.903227598102875, 4.131990263324448, 4.224204462590838, 4.191222768462355, 4.058748699426121, 3.8595927223807016, 3.626294531361786, 3.3855770546089854, 3.1555591835071506, 2.945584685324515, 2.757843818826175, 2.5897950151306417, 2.4365969263826397, 2.293092437069928, 2.1551596237579296, 2.0204040711477687, 1.8882549139057827, 1.759600865825411, 1.6361718886687255, 1.519898921624409, 1.4124319161956076, 1.3148786478334256, 1.2277084901511717, 1.150717579548993, 1.082996847661401, 1.0229374771053437, 0.968370185294651, 0.9169092673201129, 0.8664663328811567, 0.8157770207448276, 0.7647266802385058, 0.7143126944750913, 0.6662241576244406, 0.6221844590520338, 0.5833064252605725, 0.549702132779791, 0.5204770683164806, 0.4940790662629514, 0.46884201797940167, 0.44351540558918684, 0.4176099749575049, 0.39148341884456667, 0.36618842774897714, 0.34317012131955504, 0.3239159400713598, 
0.30963743081387246, 0.30102390355979997, 0.29807881880546605, 0.3000450915053438, 0.30543807362770803, 0.31221088774460337, 0.3180547293426855, 0.32078849996664877, 0.3187452796131339, 0.31105095617417466, 0.2977247846057101, 0.27959359483251367, 0.2580662751112952, 0.23483947142947, 0.2116001357610296, 0.18977270960614026, 0.17034294103334513, 0.15377839793640108, 0.14005135951583936, 0.12874967388610778, 0.11924050062636579, 0.11084025599854261, 0.10294814429005092, 0.09511945730568779, 0.08708037366461188, 0.07870697687724516, 0.06999807859790123, 0.06106141389939897, 0.052112403994384604, 0.04346702011642739, 0.03550714799588599, 0.02861090214595683, 0.02306301185105863, 0.0189775043006975, 0.01626569596231346, 0.014665840743712973, 0.013825007564039898, 0.013401499302922374, 0.01314787364789969, 0.012944152436654298, 0.012773727668138633, 0.012660177984191154, 0.012599580824937243, 0.012522178903274993, 0.012300206153013719, 0.011794468754031406, 0.010913592959836186, 0.009655843847734056, 0.008114762366543495, 0.006448977456065783, 0.004832650255179074, 0.003408566288940023, 0.002260276583995862 ], "yaxis": "y" }, { "fill": "tonexty", "fillcolor": "rgba(0, 98, 98, 0.2)", "fillpattern": { "shape": "/" }, "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "dashdot", "width": 2 }, "mode": "lines", "name": "RainTomorrow=1", "showlegend": true, "type": "scatter", "x": [ 0, 0.010101010101010102, 0.020202020202020204, 0.030303030303030304, 0.04040404040404041, 0.05050505050505051, 0.06060606060606061, 0.07070707070707072, 0.08080808080808081, 0.09090909090909091, 0.10101010101010102, 0.11111111111111112, 0.12121212121212122, 0.13131313131313133, 0.14141414141414144, 0.15151515151515152, 0.16161616161616163, 0.17171717171717174, 0.18181818181818182, 0.19191919191919193, 0.20202020202020204, 0.21212121212121213, 0.22222222222222224, 0.23232323232323235, 0.24242424242424243, 
0.25252525252525254, 0.26262626262626265, 0.27272727272727276, 0.2828282828282829, 0.29292929292929293, 0.30303030303030304, 0.31313131313131315, 0.32323232323232326, 0.33333333333333337, 0.3434343434343435, 0.3535353535353536, 0.36363636363636365, 0.37373737373737376, 0.38383838383838387, 0.393939393939394, 0.4040404040404041, 0.4141414141414142, 0.42424242424242425, 0.43434343434343436, 0.4444444444444445, 0.4545454545454546, 0.4646464646464647, 0.4747474747474748, 0.48484848484848486, 0.494949494949495, 0.5050505050505051, 0.5151515151515152, 0.5252525252525253, 0.5353535353535354, 0.5454545454545455, 0.5555555555555556, 0.5656565656565657, 0.5757575757575758, 0.5858585858585859, 0.595959595959596, 0.6060606060606061, 0.6161616161616162, 0.6262626262626263, 0.6363636363636365, 0.6464646464646465, 0.6565656565656566, 0.6666666666666667, 0.6767676767676768, 0.686868686868687, 0.696969696969697, 0.7070707070707072, 0.7171717171717172, 0.7272727272727273, 0.7373737373737375, 0.7474747474747475, 0.7575757575757577, 0.7676767676767677, 0.7777777777777778, 0.787878787878788, 0.797979797979798, 0.8080808080808082, 0.8181818181818182, 0.8282828282828284, 0.8383838383838385, 0.8484848484848485, 0.8585858585858587, 0.8686868686868687, 0.8787878787878789, 0.888888888888889, 0.8989898989898991, 0.9090909090909092, 0.9191919191919192, 0.9292929292929294, 0.9393939393939394, 0.9494949494949496, 0.9595959595959597, 0.9696969696969697, 0.9797979797979799, 0.98989898989899, 1 ], "xaxis": "x", "y": [ 0.49460059636526654, 0.5437299765635728, 0.5924549715227454, 0.6401268643625008, 0.6861590145034507, 0.7300513898235731, 0.7714077909146727, 0.8099444393111012, 0.8454895276852679, 0.8779743156360613, 0.9074172723135195, 0.933903494505514, 0.9575620674335061, 0.9785441259276458, 0.9970041053149281, 1.013086084987271, 1.02691630971778, 1.0386020439115358, 1.0482360067450938, 1.0559048809144613, 1.061699889020974, 1.0657272536866986, 1.0681165150857943, 1.0690251368100538, 
1.0686385079720726, 1.0671652357327344, 1.0648283948447264, 1.061854042743688, 1.0584587278292605, 1.0548378598966015, 1.0511566644255648, 1.0475450395816779, 1.0440970458211427, 1.0408750764226138, 1.0379180851360674, 1.0352526792878727, 1.0329054972836857, 1.0309151227976001, 1.0293418545992685, 1.0282739294398255, 1.027829238363471, 1.0281521204187258, 1.0294053921915476, 1.0317583108019557, 1.0353716172187881, 1.0403811269566785, 1.0468815045329305, 1.0549118704037497, 1.0644447513042488, 1.07537961335603, 1.0875418350329296, 1.1006875120772692, 1.114513971075819, 1.1286753394580675, 1.142802018511987, 1.1565224775400516, 1.1694854773662722, 1.1813806823259734, 1.1919556642104079, 1.2010275556049437, 1.2084880675849956, 1.2143012159007327, 1.2184938428812089, 1.2211398002465637, 1.222339379113234, 1.2221961453715042, 1.2207936819124179, 1.218174800253226, 1.2143255448320802, 1.2091657956850377, 1.2025475401625636, 1.1942610235451034, 1.1840481117087707, 1.1716214182559084, 1.1566871612077547, 1.13896938927523, 1.1182331855367673, 1.0943047062066922, 1.0670863949583325, 1.0365663506872762, 1.002821524414024, 0.966015083969502, 0.9263888310453431, 0.8842519261834968, 0.8399673459972671, 0.7939374677203351, 0.7465899811557367, 0.6983650191449953, 0.649704036350243, 0.6010406131005545, 0.5527930669040159, 0.505358553057559, 0.4591082421269574, 0.4143831713329451, 0.3714904587694544, 0.33069971331496967, 0.29223963451296087, 0.25629494308974676, 0.22300388857319367, 0.19245663011004846 ], "yaxis": "y" } ], "layout": { "font": { "size": 12 }, "height": 600, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "toggleitem", "traceorder": "grouped", "x": 0.99, "xanchor": "right", "y": 0.99, "yanchor": "top" }, "margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 35 }, "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, 
"marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" 
], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": 
"scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 
0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "x": 0.5, "xanchor": "center", "xref": "paper", 
"y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "domain": [ 0, 1 ], "range": [ 0, 1 ], "title": { "font": { "size": 16 }, "text": "Probability" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ -0.031529360583328725, 4.448190453284215 ], "title": { "font": { "size": 16 }, "text": "Probability density" }, "type": "linear" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize the distribution of predicted probabilities\n", "atom.winner.plot_probabilities()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plotly.com" }, "data": [ { "hovertemplate": "(%{x}, %{y})RF - f1_score", "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "solid", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "circle" }, "mode": "lines", "name": "f1_score", "showlegend": true, "type": "scatter", "x": [ 0, 0.02040816326530612, 0.04081632653061224, 0.061224489795918366, 0.08163265306122448, 0.1020408163265306, 0.12244897959183673, 0.14285714285714285, 0.16326530612244897, 0.18367346938775508, 0.2040816326530612, 0.22448979591836732, 0.24489795918367346, 0.26530612244897955, 0.2857142857142857, 0.3061224489795918, 0.32653061224489793, 0.3469387755102041, 0.36734693877551017, 0.3877551020408163, 0.4081632653061224, 0.42857142857142855, 0.44897959183673464, 0.4693877551020408, 0.4897959183673469, 0.5102040816326531, 0.5306122448979591, 0.5510204081632653, 0.5714285714285714, 0.5918367346938775, 0.6122448979591836, 0.6326530612244897, 0.6530612244897959, 0.673469387755102, 0.6938775510204082, 0.7142857142857142, 0.7346938775510203, 0.7551020408163265, 0.7755102040816326, 0.7959183673469387, 0.8163265306122448, 0.836734693877551, 0.8571428571428571, 0.8775510204081632, 0.8979591836734693, 0.9183673469387754, 0.9387755102040816, 0.9591836734693877, 0.9795918367346939, 1 ], "xaxis": "x", "y": [ 0.35842771485676217, 0.4006069802731411, 0.43478260869565216, 0.45330915684496825, 0.4724104549854792, 0.4963805584281282, 0.514161220043573, 0.5341040462427745, 0.5525672371638142, 0.5755208333333334, 0.5850340136054422, 
0.5833333333333334, 0.5963855421686747, 0.6003110419906688, 0.605475040257649, 0.6197654941373534, 0.6236933797909407, 0.6306306306306306, 0.6236162361623616, 0.6133333333333333, 0.6126482213438735, 0.6060606060606061, 0.606694560669456, 0.5944798301486199, 0.5869565217391305, 0.5695067264573991, 0.5596330275229358, 0.5515587529976019, 0.5356265356265356, 0.5204081632653061, 0.4827586206896552, 0.4636118598382749, 0.4376731301939058, 0.4057142857142857, 0.3988439306358382, 0.3775811209439528, 0.3393939393939394, 0.2946708463949843, 0.25161290322580643, 0.21122112211221122, 0.16891891891891891, 0.14383561643835616, 0.11188811188811189, 0.07829181494661921, 0.05755395683453238, 0.043478260869565216, 0.014760147601476014, 0.007407407407407408, 0, 0 ], "yaxis": "y" }, { "hovertemplate": "(%{x}, %{y})RF - accuracy_score", "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "dashdot", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "x" }, "mode": "lines", "name": "accuracy_score", "showlegend": true, "type": "scatter", "x": [ 0, 0.02040816326530612, 0.04081632653061224, 0.061224489795918366, 0.08163265306122448, 0.1020408163265306, 0.12244897959183673, 0.14285714285714285, 0.16326530612244897, 0.18367346938775508, 0.2040816326530612, 0.22448979591836732, 0.24489795918367346, 0.26530612244897955, 0.2857142857142857, 0.3061224489795918, 0.32653061224489793, 0.3469387755102041, 0.36734693877551017, 0.3877551020408163, 0.4081632653061224, 0.42857142857142855, 0.44897959183673464, 0.4693877551020408, 0.4897959183673469, 0.5102040816326531, 0.5306122448979591, 0.5510204081632653, 0.5714285714285714, 0.5918367346938775, 0.6122448979591836, 0.6326530612244897, 0.6530612244897959, 0.673469387755102, 0.6938775510204082, 0.7142857142857142, 0.7346938775510203, 0.7551020408163265, 0.7755102040816326, 0.7959183673469387, 
0.8163265306122448, 0.836734693877551, 0.8571428571428571, 0.8775510204081632, 0.8979591836734693, 0.9183673469387754, 0.9387755102040816, 0.9591836734693877, 0.9795918367346939, 1 ], "xaxis": "x", "y": [ 0.21834415584415584, 0.3587662337662338, 0.4512987012987013, 0.510551948051948, 0.5576298701298701, 0.6047077922077922, 0.637987012987013, 0.6728896103896104, 0.702922077922078, 0.7353896103896104, 0.752435064935065, 0.7646103896103896, 0.7824675324675324, 0.7913961038961039, 0.8011363636363636, 0.8157467532467533, 0.8246753246753247, 0.8336038961038961, 0.8344155844155844, 0.8352272727272727, 0.8409090909090909, 0.8417207792207793, 0.8474025974025974, 0.8449675324675324, 0.8457792207792207, 0.8441558441558441, 0.8441558441558441, 0.8482142857142857, 0.8465909090909091, 0.8474025974025974, 0.8417207792207793, 0.838474025974026, 0.8352272727272727, 0.8311688311688312, 0.8311688311688312, 0.8287337662337663, 0.823051948051948, 0.8173701298701299, 0.8116883116883117, 0.8060064935064936, 0.8003246753246753, 0.797077922077922, 0.7938311688311688, 0.7897727272727273, 0.7873376623376623, 0.7857142857142857, 0.7832792207792207, 0.7824675324675324, 0.7816558441558441, 0.7816558441558441 ], "yaxis": "y" }, { "hovertemplate": "(%{x}, %{y})RF - average_precision_score", "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "dash", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "diamond" }, "mode": "lines", "name": "average_precision_score", "showlegend": true, "type": "scatter", "x": [ 0, 0.02040816326530612, 0.04081632653061224, 0.061224489795918366, 0.08163265306122448, 0.1020408163265306, 0.12244897959183673, 0.14285714285714285, 0.16326530612244897, 0.18367346938775508, 0.2040816326530612, 0.22448979591836732, 0.24489795918367346, 0.26530612244897955, 0.2857142857142857, 0.3061224489795918, 0.32653061224489793, 
0.3469387755102041, 0.36734693877551017, 0.3877551020408163, 0.4081632653061224, 0.42857142857142855, 0.44897959183673464, 0.4693877551020408, 0.4897959183673469, 0.5102040816326531, 0.5306122448979591, 0.5510204081632653, 0.5714285714285714, 0.5918367346938775, 0.6122448979591836, 0.6326530612244897, 0.6530612244897959, 0.673469387755102, 0.6938775510204082, 0.7142857142857142, 0.7346938775510203, 0.7551020408163265, 0.7755102040816326, 0.7959183673469387, 0.8163265306122448, 0.836734693877551, 0.8571428571428571, 0.8775510204081632, 0.8979591836734693, 0.9183673469387754, 0.9387755102040816, 0.9591836734693877, 0.9795918367346939, 1 ], "xaxis": "x", "y": [ 0.21834415584415584, 0.25104884842495634, 0.27839593400868534, 0.2940096168128984, 0.30998254820365995, 0.330310299222406, 0.3458124124945686, 0.36367642050821114, 0.3807561601493285, 0.4028188661395596, 0.412586852734355, 0.4123379591339248, 0.42659133076701666, 0.43193461906560526, 0.4390147491913291, 0.45607963303357263, 0.463581535473485, 0.474367175575354, 0.470087223804696, 0.46407091108241205, 0.46937747874762603, 0.46669322181913886, 0.4746199310374618, 0.4654130077107551, 0.46348314145225333, 0.4540117617664716, 0.4506406669594235, 0.45718627549482566, 0.4499229287495006, 0.4499956872848794, 0.4295211363055229, 0.41809213546858415, 0.4064027737253857, 0.3920695567225927, 0.39219330855018586, 0.3839213295997682, 0.3640042426036393, 0.34443272341041853, 0.32459746029652636, 0.30433164309995714, 0.2841045433982237, 0.2725771884491781, 0.2613378994720565, 0.24690009494842208, 0.23828600664102523, 0.2325924540143871, 0.22415572345869741, 0.22124993965142664, 0.21834415584415584, 0.21834415584415584 ], "yaxis": "y" } ], "layout": { "font": { "size": 12 }, "height": 600, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "toggleitem", "traceorder": "grouped", "x": 0.01, "xanchor": "left", "y": 0.01, "yanchor": "bottom" }, 
"margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 35 }, "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, 
"#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], 
"scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, 
"#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", 
"title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ 0, 1 ], "title": { "font": { "size": 16 }, "text": "Threshold" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ -0.04712301587301587, 0.8953373015873016 ], "title": { "font": { "size": 16 }, "text": "Score" }, "type": "linear" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compare how different metrics perform for different thresholds\n", "atom.winner.plot_threshold(metric=[\"f1\", \"accuracy\", \"ap\"], steps=50)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }