{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Binary classification\n", "--------------------------------\n", "\n", "This example shows how to use ATOM to solve a binary classification problem. Additionally, we'll perform a variety of data cleaning steps to prepare the data for modelling.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal is to predict whether it will rain tomorrow by training a binary classifier on the target column `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "

5 rows × 22 columns

\n", "
" ], "text/plain": [ " Location MinTemp MaxTemp Rainfall Evaporation Sunshine \\\n", "0 MelbourneAirport 18.0 26.9 21.4 7.0 8.9 \n", "1 Adelaide 17.2 23.4 0.0 NaN NaN \n", "2 Cairns 18.6 24.6 7.4 3.0 6.1 \n", "3 Portland 13.6 16.8 4.2 1.2 0.0 \n", "4 Walpole 16.4 19.9 0.0 NaN NaN \n", "\n", " WindGustDir WindGustSpeed WindDir9am WindDir3pm ... Humidity9am \\\n", "0 SSE 41.0 W SSE ... 95.0 \n", "1 S 41.0 S WSW ... 59.0 \n", "2 SSE 54.0 SSE SE ... 78.0 \n", "3 ESE 39.0 ESE ESE ... 76.0 \n", "4 SE 44.0 SE SE ... 78.0 \n", "\n", " Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am \\\n", "0 54.0 1019.5 1017.0 8.0 5.0 18.5 \n", "1 36.0 1015.7 1015.7 NaN NaN 17.7 \n", "2 57.0 1018.7 1016.6 3.0 3.0 20.8 \n", "3 74.0 1021.4 1020.5 7.0 8.0 15.6 \n", "4 70.0 1019.4 1018.9 NaN NaN 17.4 \n", "\n", " Temp3pm RainToday RainTomorrow \n", "0 26.0 Yes 0 \n", "1 21.9 No 0 \n", "2 24.1 Yes 0 \n", "3 16.0 Yes 1 \n", "4 18.1 No 0 \n", "\n", "[5 rows x 22 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load data\n", "X = pd.read_csv(\"./datasets/weatherAUS.csv\")\n", "\n", "# Let's have a look\n", "X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<< ================== ATOM ================== >>\n", "Algorithm task: binary classification.\n", "Parallel processing with 8 cores.\n", "Parallelization backend: loky\n", "\n", "Dataset stats ==================== >>\n", "Shape: (7109, 22)\n", "Train set size: 5688\n", "Test set size: 1421\n", "-------------------------------------\n", "Memory: 3.08 MB\n", "Scaled: False\n", "Missing values: 15681 (10.0%)\n", "Categorical features: 5 (23.8%)\n", "Duplicate samples: 2 (0.0%)\n", "\n" ] } ], "source": [ "# Call atom using only 5% of the complete dataset (for explanatory purposes)\n", "atom = 
ATOMClassifier(X, \"RainTomorrow\", n_rows=0.05, n_jobs=8, verbose=2)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 11 samples for containing more than 16 missing values.\n", " --> Imputing 23 missing values with median (11.9) in feature MinTemp.\n", " --> Imputing 23 missing values with median (22.4) in feature MaxTemp.\n", " --> Imputing 69 missing values with median (0.0) in feature Rainfall.\n", " --> Imputing 2986 missing values with median (4.8) in feature Evaporation.\n", " --> Imputing 3358 missing values with median (8.4) in feature Sunshine.\n", " --> Dropping 474 samples due to missing values in feature WindGustDir.\n", " --> Imputing 471 missing values with median (39.0) in feature WindGustSpeed.\n", " --> Dropping 490 samples due to missing values in feature WindDir9am.\n", " --> Dropping 179 samples due to missing values in feature WindDir3pm.\n", " --> Imputing 50 missing values with median (13.0) in feature WindSpeed9am.\n", " --> Imputing 121 missing values with median (19.0) in feature WindSpeed3pm.\n", " --> Imputing 73 missing values with median (69.0) in feature Humidity9am.\n", " --> Imputing 176 missing values with median (52.0) in feature Humidity3pm.\n", " --> Imputing 695 missing values with median (1017.6) in feature Pressure9am.\n", " --> Imputing 697 missing values with median (1015.1) in feature Pressure3pm.\n", " --> Imputing 2605 missing values with median (5.0) in feature Cloud9am.\n", " --> Imputing 2756 missing values with median (5.0) in feature Cloud3pm.\n", " --> Imputing 36 missing values with median (16.6) in feature Temp9am.\n", " --> Imputing 131 missing values with median (20.9) in feature Temp3pm.\n", " --> Dropping 69 samples due to missing values in feature RainToday.\n" ] } ], "source": [ "# Impute missing values\n", "atom.impute(strat_num=\"median\", 
strat_cat=\"drop\", max_nan_rows=0.8)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> Target-encoding feature Location. Contains 47 classes.\n", " --> Target-encoding feature WindGustDir. Contains 16 classes.\n", " --> Target-encoding feature WindDir9am. Contains 16 classes.\n", " --> Target-encoding feature WindDir3pm. Contains 16 classes.\n", " --> Ordinal-encoding feature RainToday. Contains 2 classes.\n" ] } ], "source": [ "# Encode the categorical features\n", "atom.encode(strategy=\"Target\", max_onehot=10, infrequent_to_value=0.04)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Training ========================= >>\n", "Models: ET, RF\n", "Metric: f1\n", "\n", "\n", "Results for ExtraTrees:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 1.0\n", "Test evaluation --> f1: 0.5688\n", "Time elapsed: 9.395s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.5463 ± 0.0135\n", "Time elapsed: 4.742s\n", "-------------------------------------------------\n", "Total time: 14.138s\n", "\n", "\n", "Results for RandomForest:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 1.0\n", "Test evaluation --> f1: 0.5969\n", "Time elapsed: 1.368s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.576 ± 0.0117\n", "Time elapsed: 5.341s\n", "-------------------------------------------------\n", "Total time: 6.709s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 20.861s\n", "-------------------------------------\n", "ExtraTrees --> f1: 0.5463 ± 0.0135 ~\n", "RandomForest --> f1: 0.576 ± 0.0117 ~ !\n" ] } ], "source": [ "# Train an Extra-Trees and a Random Forest model\n", 
"atom.run(models=[\"ET\", \"RF\"], metric=\"f1\", n_bootstrap=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyze the results" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " score_train score_test time_fit score_bootstrap time_bootstrap \\\n", "ET 1.0 0.5688 9.395485 0.546307 4.742292 \n", "RF 1.0 0.5969 1.367760 0.575995 5.341101 \n", "\n", " time \n", "ET 14.137777 \n", "RF 6.708861 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's have a look at the final results\n", "atom.results" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "boxpoints": "outliers", "legendgroup": "f1", "marker": { "color": "rgb(0, 98, 98)" }, "name": "f1", "orientation": "h", "showlegend": true, "type": "box", "x": [ 0.5305164319248827, 0.5505882352941176, 0.5668202764976958, 0.5433255269320844, 0.5402843601895735, 0.588495575221239, 0.5637583892617449, 0.5649202733485194, 0.5758928571428571, 0.5869074492099323 ], "xaxis": "x", "y": [ "ET", "ET", "ET", "ET", "ET", "RF", "RF", "RF", "RF", "RF" ], "yaxis": "y" } ], "layout": { "bargroupgap": 0.05, "boxmode": "group", "font": { "size": 12 }, "height": 500, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "toggleitem", "traceorder": "grouped", "x": 0.99, "xanchor": "right", "y": 0.01, "yanchor": "bottom" }, "margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 49 }, "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { 
"endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, 
"#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, 
"type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": 
true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "text": "RF vs ET performance", "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ 0.5272953684084185, 0.5917166387377032 ], "title": { "font": { "size": 16 }, "text": "time (s)" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, 
"autorange": true, "categoryorder": "total ascending", "domain": [ 0, 1 ], "range": [ -0.5, 1.5 ], "title": { "font": { "size": 16 } }, "type": "category" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize the bootstrap results\n", "atom.plot_results(title=\"RF vs ET performance\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " accuracy average_precision balanced_accuracy f1 jaccard \\\n", "ET 0.8427 0.6813 0.6911 0.5403 0.3701 \n", "RF 0.8516 0.6871 0.7180 0.5869 0.4153 \n", "\n", " matthews_corrcoef precision recall roc_auc \n", "ET 0.4828 0.7550 0.4207 0.8613 \n", "RF 0.5212 0.7558 0.4797 0.8652 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Print the results of some common metrics\n", "atom.evaluate()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The winner is the RF model!!\n" ] } ], "source": [ "# The winner attribute returns the best model (atom.winner == atom.rf)\n", "print(f\"The winner is the {atom.winner.name} model!!\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "fill": "tonexty", "fillcolor": "rgba(0, 98, 98, 0.2)", "fillpattern": { "shape": "" }, "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "width": 2 }, "mode": "lines", "name": "RainTomorrow=0", "showlegend": true, "type": "scatter", "x": [ 0, 0.010101010101010102, 0.020202020202020204, 0.030303030303030304, 0.04040404040404041, 0.05050505050505051, 0.06060606060606061, 0.07070707070707072, 0.08080808080808081, 0.09090909090909091, 0.10101010101010102, 0.11111111111111112, 0.12121212121212122, 0.13131313131313133, 0.14141414141414144, 0.15151515151515152, 0.16161616161616163, 0.17171717171717174, 0.18181818181818182, 0.19191919191919193, 0.20202020202020204, 0.21212121212121213, 0.22222222222222224, 0.23232323232323235, 0.24242424242424243, 0.25252525252525254, 0.26262626262626265, 0.27272727272727276, 0.2828282828282829, 0.29292929292929293, 0.30303030303030304, 0.31313131313131315, 0.32323232323232326, 0.33333333333333337, 0.3434343434343435, 
0.3535353535353536, 0.36363636363636365, 0.37373737373737376, 0.38383838383838387, 0.393939393939394, 0.4040404040404041, 0.4141414141414142, 0.42424242424242425, 0.43434343434343436, 0.4444444444444445, 0.4545454545454546, 0.4646464646464647, 0.4747474747474748, 0.48484848484848486, 0.494949494949495, 0.5050505050505051, 0.5151515151515152, 0.5252525252525253, 0.5353535353535354, 0.5454545454545455, 0.5555555555555556, 0.5656565656565657, 0.5757575757575758, 0.5858585858585859, 0.595959595959596, 0.6060606060606061, 0.6161616161616162, 0.6262626262626263, 0.6363636363636365, 0.6464646464646465, 0.6565656565656566, 0.6666666666666667, 0.6767676767676768, 0.686868686868687, 0.696969696969697, 0.7070707070707072, 0.7171717171717172, 0.7272727272727273, 0.7373737373737375, 0.7474747474747475, 0.7575757575757577, 0.7676767676767677, 0.7777777777777778, 0.787878787878788, 0.797979797979798, 0.8080808080808082, 0.8181818181818182, 0.8282828282828284, 0.8383838383838385, 0.8484848484848485, 0.8585858585858587, 0.8686868686868687, 0.8787878787878789, 0.888888888888889, 0.8989898989898991, 0.9090909090909092, 0.9191919191919192, 0.9292929292929294, 0.9393939393939394, 0.9494949494949496, 0.9595959595959597, 0.9696969696969697, 0.9797979797979799, 0.98989898989899, 1 ], "xaxis": "x", "y": [ 3.0933689063475573, 3.5579862850448816, 3.9285991866078875, 4.176306867946532, 4.288536269256913, 4.269940676345758, 4.1396135415703865, 3.9258168419930044, 3.6599344704108883, 3.371223943638619, 3.0833456187060007, 2.8129179020357085, 2.5697597389434694, 2.358176917155678, 2.1786194193084634, 2.02919117959044, 1.9067250758778036, 1.8073651345075599, 1.7267755781191196, 1.660192888880494, 1.6025385249050115, 1.54872514386165, 1.4941553900321678, 1.4352866321796072, 1.3700720496308434, 1.2981153634388398, 1.2204798433013844, 1.1392233448381313, 1.05682875386417, 0.9757189534939114, 0.897981726642871, 0.8253182761781456, 0.7591263453144552, 0.7005839377433689, 0.6506287921856995, 
0.6098106508928158, 0.5780824799749964, 0.5546475764148895, 0.5379679126231703, 0.5259723260817694, 0.5164144524374739, 0.5072624005979143, 0.4969875578841436, 0.48466573622473386, 0.46988804987670535, 0.45256002467727796, 0.4327053706871094, 0.41036869614351373, 0.3856454627302735, 0.35879703611334723, 0.3303712359238004, 0.3012572115104722, 0.2726420421479381, 0.24587705798230516, 0.2222855855177291, 0.20295037665864796, 0.18851908650777527, 0.17906697314377107, 0.17405394540573696, 0.17239921273051123, 0.1726677852411491, 0.1733275566706943, 0.17301048974605118, 0.17071084396519684, 0.16587946360817285, 0.15841333165729987, 0.14857373096404533, 0.13687862971988693, 0.12400319638390378, 0.11069861509178927, 0.09772085121841993, 0.085758454674432, 0.07535985647006689, 0.06687431058340534, 0.06042488142753736, 0.05592288101998652, 0.05311663663474557, 0.051653813630086004, 0.05113394536043157, 0.051137830549875266, 0.051237366633419615, 0.05100344219751991, 0.05003203998946364, 0.04799715972619033, 0.044719386714712483, 0.04022252656982395, 0.03474837580433914, 0.028714231288953732, 0.02262189589373659, 0.016947689011280978, 0.012049583098347368, 0.008117828048597283, 0.005175957434518472, 0.0031204374597477093, 0.0017774248298854816, 9.560044070347473E-4, 4.853023469668028E-4, 2.3242186768467345E-4, 1.0498096461505306E-4, 4.470871658505219E-5 ], "yaxis": "y" }, { "fill": "tonexty", "fillcolor": "rgba(0, 98, 98, 0.2)", "fillpattern": { "shape": "/" }, "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "width": 2 }, "mode": "lines", "name": "RainTomorrow=1", "showlegend": true, "type": "scatter", "x": [ 0, 0.010101010101010102, 0.020202020202020204, 0.030303030303030304, 0.04040404040404041, 0.05050505050505051, 0.06060606060606061, 0.07070707070707072, 0.08080808080808081, 0.09090909090909091, 0.10101010101010102, 0.11111111111111112, 0.12121212121212122, 0.13131313131313133, 0.14141414141414144, 
0.15151515151515152, 0.16161616161616163, 0.17171717171717174, 0.18181818181818182, 0.19191919191919193, 0.20202020202020204, 0.21212121212121213, 0.22222222222222224, 0.23232323232323235, 0.24242424242424243, 0.25252525252525254, 0.26262626262626265, 0.27272727272727276, 0.2828282828282829, 0.29292929292929293, 0.30303030303030304, 0.31313131313131315, 0.32323232323232326, 0.33333333333333337, 0.3434343434343435, 0.3535353535353536, 0.36363636363636365, 0.37373737373737376, 0.38383838383838387, 0.393939393939394, 0.4040404040404041, 0.4141414141414142, 0.42424242424242425, 0.43434343434343436, 0.4444444444444445, 0.4545454545454546, 0.4646464646464647, 0.4747474747474748, 0.48484848484848486, 0.494949494949495, 0.5050505050505051, 0.5151515151515152, 0.5252525252525253, 0.5353535353535354, 0.5454545454545455, 0.5555555555555556, 0.5656565656565657, 0.5757575757575758, 0.5858585858585859, 0.595959595959596, 0.6060606060606061, 0.6161616161616162, 0.6262626262626263, 0.6363636363636365, 0.6464646464646465, 0.6565656565656566, 0.6666666666666667, 0.6767676767676768, 0.686868686868687, 0.696969696969697, 0.7070707070707072, 0.7171717171717172, 0.7272727272727273, 0.7373737373737375, 0.7474747474747475, 0.7575757575757577, 0.7676767676767677, 0.7777777777777778, 0.787878787878788, 0.797979797979798, 0.8080808080808082, 0.8181818181818182, 0.8282828282828284, 0.8383838383838385, 0.8484848484848485, 0.8585858585858587, 0.8686868686868687, 0.8787878787878789, 0.888888888888889, 0.8989898989898991, 0.9090909090909092, 0.9191919191919192, 0.9292929292929294, 0.9393939393939394, 0.9494949494949496, 0.9595959595959597, 0.9696969696969697, 0.9797979797979799, 0.98989898989899, 1 ], "xaxis": "x", "y": [ 0.3578196795805163, 0.39903613801665105, 0.44114568514060254, 0.4836832769782436, 0.526195418378348, 0.5682607702730456, 0.6095085467995134, 0.649633362622719, 0.6884054303684506, 0.7256753337641916, 0.7613729932109744, 0.7955008743145477, 0.8281219413983371, 0.8593432989051414, 
0.8892968611346839, 0.9181187077215955, 0.9459289788759886, 0.9728142029645707, 0.9988138006857528, 1.0239121630432009, 1.0480371671206254, 1.0710653159381236, 1.0928329369926684, 1.1131521418580925, 1.1318296406603363, 1.1486861188883777, 1.1635737945212932, 1.1763900153660203, 1.1870853134553874, 1.1956651358688235, 1.2021854043871603, 1.2067429762697262, 1.2094628362762962, 1.2104843178518117, 1.209948745630727, 1.2079905895494394, 1.2047335672110475, 1.2002922308343695, 1.1947785767791534, 1.188312286134287, 1.1810325005334548, 1.1731086771836539, 1.1647481124969332, 1.1561981687715996, 1.1477420117321315, 1.139687644756075, 1.1323510538962325, 1.1260351971422102, 1.1210072433868978, 1.1174767947495294, 1.115577768720398, 1.115356191083087, 1.1167654265675893, 1.119669460776384, 1.1238538740416184, 1.1290432463511015, 1.13492301577845, 1.1411633629729412, 1.1474425534628354, 1.1534673382274343, 1.158988453394055, 1.1638099040109624, 1.167791477235819, 1.1708447120599317, 1.1729232649262737, 1.1740091762061446, 1.1740969057469361, 1.1731771369348678, 1.1712222461134816, 1.1681750226552354, 1.163941752315061, 1.1583902073415644, 1.151352494787833, 1.142632173198502, 1.132014621974048, 1.1192793846321236, 1.1042131308731518, 1.0866219901012575, 1.0663422718564912, 1.0432489554945534, 1.0172617372470631, 0.988348798509153, 0.9565287432837905, 0.9218713014344564, 0.8844973894056856, 0.8445789716293394, 0.8023389109704677, 0.75805069227607, 0.7120376157417723, 0.664670849687823, 0.6163656536766163, 0.5675751567692635, 0.5187812974506649, 0.4704828688909871, 0.4231810116246363, 0.37736288967125065, 0.33348461019590886, 0.29195464745858907, 0.2531190762493785, 0.21724980100253127 ], "yaxis": "y" } ], "layout": { "font": { "size": 12 }, "height": 600, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "toggleitem", "traceorder": "grouped", "x": 0.99, "xanchor": "right", "y": 0.99, "yanchor": 
"top" }, "margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 35 }, "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 
0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": 
"scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 
0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": 
"white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "domain": [ 0, 1 ], "range": [ 0, 1 ], "title": { "font": { "size": 16 }, "text": "Probability" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ -0.0089327805671566, 4.514718850826601 ], "title": { "font": { "size": 16 }, "text": "Probability density" }, "type": "linear" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize the distribution of predicted probabilities\n", "atom.winner.plot_probabilities()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "hovertemplate": "(%{x}, %{y})RF - f1_score", "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "circle" }, "name": "f1_score", "showlegend": true, "type": "scatter", "x": [ 0, 0.02040816326530612, 0.04081632653061224, 0.061224489795918366, 0.08163265306122448, 0.1020408163265306, 0.12244897959183673, 0.14285714285714285, 0.16326530612244897, 0.18367346938775508, 0.2040816326530612, 0.22448979591836732, 0.24489795918367346, 0.26530612244897955, 0.2857142857142857, 0.3061224489795918, 0.32653061224489793, 0.3469387755102041, 0.36734693877551017, 0.3877551020408163, 0.4081632653061224, 0.42857142857142855, 0.44897959183673464, 0.4693877551020408, 0.4897959183673469, 0.5102040816326531, 0.5306122448979591, 0.5510204081632653, 0.5714285714285714, 0.5918367346938775, 0.6122448979591836, 0.6326530612244897, 0.6530612244897959, 0.673469387755102, 0.6938775510204082, 0.7142857142857142, 0.7346938775510203, 0.7551020408163265, 0.7755102040816326, 0.7959183673469387, 0.8163265306122448, 0.836734693877551, 0.8571428571428571, 0.8775510204081632, 0.8979591836734693, 0.9183673469387754, 0.9387755102040816, 0.9591836734693877, 0.9795918367346939, 1 ], "xaxis": "x", "y": [ 0.3603723404255319, 0.4018058690744921, 0.4370860927152318, 0.47368421052631576, 0.495164410058027, 0.5195071868583162, 0.5347826086956522, 0.5508571428571428, 0.5690276110444178, 0.5760598503740648, 0.5908496732026144, 0.6065573770491803, 0.6242774566473989, 
0.6334841628959277, 0.629746835443038, 0.6214876033057851, 0.6348122866894198, 0.6373239436619718, 0.6304347826086957, 0.6245353159851301, 0.6165703275529866, 0.6099009900990099, 0.5945945945945945, 0.5872340425531916, 0.5903083700440529, 0.5785876993166287, 0.5714285714285714, 0.557142857142857, 0.5330073349633251, 0.5087281795511222, 0.4923076923076924, 0.45744680851063824, 0.44141689373297, 0.41666666666666663, 0.4034090909090909, 0.36842105263157887, 0.34029850746268653, 0.31610942249240126, 0.2795031055900621, 0.2315112540192926, 0.19672131147540983, 0.1610738255033557, 0.13698630136986298, 0.11805555555555554, 0.0642857142857143, 0.04332129963898917, 0.0218978102189781, 0, 0, 0 ], "yaxis": "y" }, { "hovertemplate": "(%{x}, %{y})RF - accuracy_score", "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "dashdot", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "x" }, "name": "accuracy_score", "showlegend": true, "type": "scatter", "x": [ 0, 0.02040816326530612, 0.04081632653061224, 0.061224489795918366, 0.08163265306122448, 0.1020408163265306, 0.12244897959183673, 0.14285714285714285, 0.16326530612244897, 0.18367346938775508, 0.2040816326530612, 0.22448979591836732, 0.24489795918367346, 0.26530612244897955, 0.2857142857142857, 0.3061224489795918, 0.32653061224489793, 0.3469387755102041, 0.36734693877551017, 0.3877551020408163, 0.4081632653061224, 0.42857142857142855, 0.44897959183673464, 0.4693877551020408, 0.4897959183673469, 0.5102040816326531, 0.5306122448979591, 0.5510204081632653, 0.5714285714285714, 0.5918367346938775, 0.6122448979591836, 0.6326530612244897, 0.6530612244897959, 0.673469387755102, 0.6938775510204082, 0.7142857142857142, 0.7346938775510203, 0.7551020408163265, 0.7755102040816326, 0.7959183673469387, 0.8163265306122448, 0.836734693877551, 0.8571428571428571, 0.8775510204081632, 
0.8979591836734693, 0.9183673469387754, 0.9387755102040816, 0.9591836734693877, 0.9795918367346939, 1 ], "xaxis": "x", "y": [ 0.2197891321978913, 0.35523114355231145, 0.44849959448499593, 0.5296025952960259, 0.5766423357664233, 0.6204379562043796, 0.6528791565287916, 0.681265206812652, 0.7088402270884022, 0.7242497972424979, 0.7461476074614761, 0.7664233576642335, 0.7891321978913219, 0.8029197080291971, 0.8102189781021898, 0.8142741281427412, 0.8264395782643957, 0.8329278183292782, 0.8345498783454988, 0.8361719383617194, 0.8386050283860503, 0.8402270884022709, 0.8418491484184915, 0.8426601784266018, 0.8491484184914841, 0.8499594484995945, 0.851581508515815, 0.8491484184914841, 0.8450932684509327, 0.8402270884022709, 0.8394160583941606, 0.8345498783454988, 0.8337388483373885, 0.829683698296837, 0.829683698296837, 0.8248175182481752, 0.8207623682076237, 0.8175182481751825, 0.8118410381184104, 0.8061638280616383, 0.8012976480129764, 0.797242497972425, 0.7956204379562044, 0.7939983779399837, 0.7875101378751014, 0.7850770478507705, 0.7826439578264396, 0.7802108678021087, 0.7802108678021087, 0.7802108678021087 ], "yaxis": "y" }, { "hovertemplate": "(%{x}, %{y})RF - average_precision_score", "legendgroup": "RF", "legendgrouptitle": { "font": { "size": 16 }, "text": "RF" }, "line": { "color": "rgb(0, 98, 98)", "dash": "dash", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "diamond" }, "name": "average_precision_score", "showlegend": true, "type": "scatter", "x": [ 0, 0.02040816326530612, 0.04081632653061224, 0.061224489795918366, 0.08163265306122448, 0.1020408163265306, 0.12244897959183673, 0.14285714285714285, 0.16326530612244897, 0.18367346938775508, 0.2040816326530612, 0.22448979591836732, 0.24489795918367346, 0.26530612244897955, 0.2857142857142857, 0.3061224489795918, 0.32653061224489793, 0.3469387755102041, 0.36734693877551017, 0.3877551020408163, 0.4081632653061224, 
0.42857142857142855, 0.44897959183673464, 0.4693877551020408, 0.4897959183673469, 0.5102040816326531, 0.5306122448979591, 0.5510204081632653, 0.5714285714285714, 0.5918367346938775, 0.6122448979591836, 0.6326530612244897, 0.6530612244897959, 0.673469387755102, 0.6938775510204082, 0.7142857142857142, 0.7346938775510203, 0.7551020408163265, 0.7755102040816326, 0.7959183673469387, 0.8163265306122448, 0.836734693877551, 0.8571428571428571, 0.8775510204081632, 0.8979591836734693, 0.9183673469387754, 0.9387755102040816, 0.9591836734693877, 0.9795918367346939, 1 ], "xaxis": "x", "y": [ 0.2197891321978913, 0.25188215461694585, 0.2801497946972396, 0.3106000758813881, 0.3291120534015798, 0.35058086156237167, 0.36435321254902964, 0.3791670479425343, 0.39637516735195405, 0.4032586373123412, 0.4180192749059498, 0.434230228744012, 0.45354337639289843, 0.46460198178624124, 0.4631840533582156, 0.4577965212319635, 0.474209031975617, 0.48002737713729715, 0.4762484061446064, 0.4736020139665023, 0.4709313657770437, 0.4688773277401499, 0.463133577112622, 0.4609979595742845, 0.4731783383407605, 0.4710538281086669, 0.47291083803594913, 0.46391147446476666, 0.4490776469173768, 0.43238056299709543, 0.42770655456048523, 0.4099597221206831, 0.4062866609505511, 0.3921805059106935, 0.3918538536561359, 0.3749725016526406, 0.3608874512558994, 0.3496476369243934, 0.3298089528769912, 0.31014924747787626, 0.29313567931729123, 0.27904519921111615, 0.2738549491855711, 0.26873224936628926, 0.24570019422821965, 0.23706317355144355, 0.22842615287466742, 0.2197891321978913, 0.2197891321978913, 0.2197891321978913 ], "yaxis": "y" } ], "layout": { "font": { "size": 12 }, "height": 600, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "toggleitem", "traceorder": "grouped", "x": 0.01, "xanchor": "left", "y": 0.01, "yanchor": "bottom" }, "margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 35 }, "showlegend": true, "template": { 
"data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 
0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": 
"scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 
0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": 
{ "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ 0, 1 ], "title": { "font": { "size": 16 }, "text": "Threshold" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ -0.047310083806434165, 0.8988915923222492 ], "title": { "font": { "size": 16 }, "text": "Score" }, "type": "linear" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compare how the f1, accuracy and average precision scores change with the decision threshold\n", "atom.winner.plot_threshold(metric=[\"f1\", \"accuracy\", \"ap\"], steps=50)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }