{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Train sizing\n", "-----------------------\n", "\n", "This example shows how to assess a model's performance depending on the size of the training set.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target column `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "UserWarning: The pandas version installed (1.5.3) does not match the supported pandas version in Modin (1.5.2). This may cause undesired side effects!\n" ] } ], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Location MinTemp MaxTemp Rainfall Evaporation Sunshine \\\n", "0 MelbourneAirport 18.0 26.9 21.4 7.0 8.9 \n", "1 Adelaide 17.2 23.4 0.0 NaN NaN \n", "2 Cairns 18.6 24.6 7.4 3.0 6.1 \n", "3 Portland 13.6 16.8 4.2 1.2 0.0 \n", "4 Walpole 16.4 19.9 0.0 NaN NaN \n", "\n", " WindGustDir WindGustSpeed WindDir9am WindDir3pm ... Humidity9am \\\n", "0 SSE 41.0 W SSE ... 95.0 \n", "1 S 41.0 S WSW ... 59.0 \n", "2 SSE 54.0 SSE SE ... 78.0 \n", "3 ESE 39.0 ESE ESE ... 76.0 \n", "4 SE 44.0 SE SE ... 78.0 \n", "\n", " Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am \\\n", "0 54.0 1019.5 1017.0 8.0 5.0 18.5 \n", "1 36.0 1015.7 1015.7 NaN NaN 17.7 \n", "2 57.0 1018.7 1016.6 3.0 3.0 20.8 \n", "3 74.0 1021.4 1020.5 7.0 8.0 15.6 \n", "4 70.0 1019.4 1018.9 NaN NaN 17.4 \n", "\n", " Temp3pm RainToday RainTomorrow \n", "0 26.0 Yes 0 \n", "1 21.9 No 0 \n", "2 24.1 Yes 0 \n", "3 16.0 Yes 1 \n", "4 18.1 No 0 \n", "\n", "[5 rows x 22 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the data\n", "X = pd.read_csv(\"./datasets/weatherAUS.csv\")\n", "\n", "# Let's have a look\n", "X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<< ================== ATOM ================== >>\n", "Algorithm task: binary classification.\n", "\n", "Dataset stats ==================== >>\n", "Shape: (142193, 22)\n", "Train set size: 113755\n", "Test set size: 28438\n", "-------------------------------------\n", "Memory: 61.69 MB\n", "Scaled: False\n", "Missing values: 316559 (10.1%)\n", "Categorical features: 5 (23.8%)\n", "Duplicate samples: 45 (0.0%)\n", "\n", "Fitting Cleaner...\n", "Cleaning the data...\n", "Fitting Imputer...\n", "Imputing missing values...\n", " --> Dropping 161 samples for containing more than 16 missing values.\n", " 
--> Imputing 481 missing values with median (12.0) in feature MinTemp.\n", " --> Imputing 265 missing values with median (22.6) in feature MaxTemp.\n", " --> Imputing 1354 missing values with median (0.0) in feature Rainfall.\n", " --> Imputing 60682 missing values with median (4.8) in feature Evaporation.\n", " --> Imputing 67659 missing values with median (8.4) in feature Sunshine.\n", " --> Imputing 9187 missing values with most_frequent (W) in feature WindGustDir.\n", " --> Imputing 9127 missing values with median (39.0) in feature WindGustSpeed.\n", " --> Imputing 9852 missing values with most_frequent (N) in feature WindDir9am.\n", " --> Imputing 3617 missing values with most_frequent (SE) in feature WindDir3pm.\n", " --> Imputing 1187 missing values with median (13.0) in feature WindSpeed9am.\n", " --> Imputing 2469 missing values with median (19.0) in feature WindSpeed3pm.\n", " --> Imputing 1613 missing values with median (70.0) in feature Humidity9am.\n", " --> Imputing 3449 missing values with median (52.0) in feature Humidity3pm.\n", " --> Imputing 13863 missing values with median (1017.6) in feature Pressure9am.\n", " --> Imputing 13830 missing values with median (1015.2) in feature Pressure3pm.\n", " --> Imputing 53496 missing values with median (5.0) in feature Cloud9am.\n", " --> Imputing 56933 missing values with median (5.0) in feature Cloud3pm.\n", " --> Imputing 743 missing values with median (16.7) in feature Temp9am.\n", " --> Imputing 2565 missing values with median (21.1) in feature Temp3pm.\n", " --> Imputing 1354 missing values with most_frequent (No) in feature RainToday.\n", "Fitting Encoder...\n", "Encoding categorical columns...\n", " --> LeaveOneOut-encoding feature Location. Contains 49 classes.\n", " --> LeaveOneOut-encoding feature WindGustDir. Contains 16 classes.\n", " --> LeaveOneOut-encoding feature WindDir9am. Contains 16 classes.\n", " --> LeaveOneOut-encoding feature WindDir3pm. 
Contains 16 classes.\n", " --> Ordinal-encoding feature RainToday. Contains 2 classes.\n" ] } ], "source": [ "# Initialize atom and prepare the data\n", "atom = ATOMClassifier(X, verbose=2, random_state=1)\n", "atom.clean()\n", "atom.impute(strat_num=\"median\", strat_cat=\"most_frequent\", max_nan_rows=0.8)\n", "atom.encode()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Training ========================= >>\n", "Metric: f1\n", "\n", "\n", "Run: 0 =========================== >>\n", "Models: LGB01\n", "Size of training set: 11362 (10%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.795\n", "Test evaluation --> f1: 0.6169\n", "Time elapsed: 2.702s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6025 ± 0.0021\n", "Time elapsed: 2.367s\n", "-------------------------------------------------\n", "Total time: 5.069s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 5.072s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6025 ± 0.0021 ~\n", "\n", "\n", "Run: 1 =========================== >>\n", "Models: LGB02\n", "Size of training set: 22724 (20%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.711\n", "Test evaluation --> f1: 0.6172\n", "Time elapsed: 3.361s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.606 ± 0.0021\n", "Time elapsed: 2.924s\n", "-------------------------------------------------\n", "Total time: 6.285s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 6.288s\n", "-------------------------------------\n", "LightGBM --> f1: 0.606 ± 0.0021\n", "\n", "\n", "Run: 2 =========================== >>\n", "Models: LGB03\n", "Size of 
training set: 34087 (30%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6844\n", "Test evaluation --> f1: 0.6205\n", "Time elapsed: 4.115s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6136 ± 0.0021\n", "Time elapsed: 3.574s\n", "-------------------------------------------------\n", "Total time: 7.689s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 7.692s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6136 ± 0.0021\n", "\n", "\n", "Run: 3 =========================== >>\n", "Models: LGB04\n", "Size of training set: 45449 (40%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6788\n", "Test evaluation --> f1: 0.6246\n", "Time elapsed: 4.704s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6209 ± 0.0012\n", "Time elapsed: 4.312s\n", "-------------------------------------------------\n", "Total time: 9.017s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 9.019s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6209 ± 0.0012\n", "\n", "\n", "Run: 4 =========================== >>\n", "Models: LGB05\n", "Size of training set: 56812 (50%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6694\n", "Test evaluation --> f1: 0.6256\n", "Time elapsed: 5.333s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6231 ± 0.0025\n", "Time elapsed: 4.956s\n", "-------------------------------------------------\n", "Total time: 10.289s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 10.295s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6231 ± 
0.0025\n", "\n", "\n", "Run: 5 =========================== >>\n", "Models: LGB06\n", "Size of training set: 68174 (60%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6623\n", "Test evaluation --> f1: 0.627\n", "Time elapsed: 6.177s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6223 ± 0.0043\n", "Time elapsed: 5.432s\n", "-------------------------------------------------\n", "Total time: 11.609s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 11.615s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6223 ± 0.0043\n", "\n", "\n", "Run: 6 =========================== >>\n", "Models: LGB07\n", "Size of training set: 79536 (70%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6609\n", "Test evaluation --> f1: 0.6307\n", "Time elapsed: 6.787s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6254 ± 0.0029\n", "Time elapsed: 6.138s\n", "-------------------------------------------------\n", "Total time: 12.925s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 12.930s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6254 ± 0.0029\n", "\n", "\n", "Run: 7 =========================== >>\n", "Models: LGB08\n", "Size of training set: 90899 (80%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6588\n", "Test evaluation --> f1: 0.6316\n", "Time elapsed: 7.660s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6255 ± 0.002\n", "Time elapsed: 7.141s\n", "-------------------------------------------------\n", "Total time: 14.802s\n", "\n", "\n", "Final results ==================== >>\n", 
"Total time: 14.808s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6255 ± 0.002\n", "\n", "\n", "Run: 8 =========================== >>\n", "Models: LGB09\n", "Size of training set: 102261 (90%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6601\n", "Test evaluation --> f1: 0.6318\n", "Time elapsed: 8.433s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6253 ± 0.0022\n", "Time elapsed: 7.353s\n", "-------------------------------------------------\n", "Total time: 15.786s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 15.792s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6253 ± 0.0022\n", "\n", "\n", "Run: 9 =========================== >>\n", "Models: LGB10\n", "Size of training set: 113624 (100%)\n", "Size of test set: 28408\n", "\n", "\n", "Results for LightGBM:\n", "Fit ---------------------------------------------\n", "Train evaluation --> f1: 0.6558\n", "Test evaluation --> f1: 0.631\n", "Time elapsed: 8.937s\n", "Bootstrap ---------------------------------------\n", "Evaluation --> f1: 0.6258 ± 0.0034\n", "Time elapsed: 8.158s\n", "-------------------------------------------------\n", "Total time: 17.095s\n", "\n", "\n", "Final results ==================== >>\n", "Total time: 17.100s\n", "-------------------------------------\n", "LightGBM --> f1: 0.6258 ± 0.0034\n" ] } ], "source": [ "# Analyze the impact of the training set's size on a LightGBM model\n", "atom.train_sizing(\"LGB\", train_sizes=10, n_bootstrap=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyze the results" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " score_train score_test time_fit score_bootstrap \\\n", "frac model \n", "0.1 LGB01 0.7950 0.6169 2.701927 0.602473 \n", "0.2 LGB02 0.7110 0.6172 3.361056 0.605984 \n", "0.3 LGB03 0.6844 0.6205 4.114851 0.613633 \n", "0.4 LGB04 0.6788 0.6246 4.704423 0.620894 \n", "0.5 LGB05 0.6694 0.6256 5.332624 0.623075 \n", "0.6 LGB06 0.6623 0.6270 6.176526 0.622287 \n", "0.7 LGB07 0.6609 0.6307 6.786634 0.625412 \n", "0.8 LGB08 0.6588 0.6316 7.660243 0.625519 \n", "0.9 LGB09 0.6601 0.6318 8.433411 0.625334 \n", "1.0 LGB10 0.6558 0.6310 8.937261 0.625840 \n", "\n", " time_bootstrap time \n", "frac model \n", "0.1 LGB01 2.366629 5.068556 \n", "0.2 LGB02 2.923961 6.285017 \n", "0.3 LGB03 3.573816 7.688667 \n", "0.4 LGB04 4.312111 9.016534 \n", "0.5 LGB05 4.956064 10.288688 \n", "0.6 LGB06 5.432179 11.608705 \n", "0.7 LGB07 6.138183 12.924817 \n", "0.8 LGB08 7.141488 14.801731 \n", "0.9 LGB09 7.352633 15.786044 \n", "1.0 LGB10 8.158222 17.095483 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The results are now multi-index, where frac is the fraction\n", "# of the training set used to fit the model. The model names\n", "# end with the fraction as well (without the dot)\n", "atom.results" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Every model can be accessed through its name\n", "atom.lgb05.plot_shap_waterfall(show=6)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "error_y": { "array": [ 0.0021093507865551924, 0.0020567842220873423, 0.002118481499414932, 0.001160167220745696, 0.00250375447453087, 0.004265897967306575, 0.0028864062364942697, 0.0019579716683319, 0.0021689829117167393, 0.003362587796705156 ], "type": "data", "visible": true }, "hovertemplate": "(%{x}, %{y})LGB - f1", "legendgroup": "LGB", "legendgrouptitle": { "font": { "size": 16 }, "text": "LGB" }, "line": { "color": "rgb(0, 98, 98)", "width": 2 }, "marker": { "color": "rgb(0, 98, 98)", "line": { "color": "rgba(255, 255, 255, 0.9)", "width": 1 }, "size": 8, "symbol": "circle" }, "mode": "lines+markers", "name": "f1", "showlegend": true, "type": "scatter", "x": [ 11362, 22724, 34087, 45449, 56812, 68174, 79536, 90899, 102261, 113624 ], "xaxis": "x", "y": [ 0.6024732288039821, 0.605983880000626, 0.6136333528411273, 0.6208941926213539, 0.6230748072286343, 0.62228677254213, 0.6254124354209738, 0.625519241877403, 0.6253336016433961, 0.6258404708878225 ], "yaxis": "y" }, { "hovertemplate": "%{y}upper bound", "legendgroup": "LGB", "line": { "color": "rgb(0, 98, 98)", "width": 1 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 11362, 22724, 34087, 45449, 56812, 68174, 79536, 90899, 102261, 113624 ], "xaxis": "x", "y": [ 0.6045825795905373, 0.6080406642227133, 0.6157518343405423, 0.6220543598420997, 0.6255785617031652, 0.6265526705094366, 0.6282988416574681, 0.6274772135457349, 0.6275025845551129, 0.6292030586845276 ], "yaxis": "y" }, { "fill": "tonexty", "fillcolor": "rgba(0, 98, 98, 0.2)", "hovertemplate": 
"%{y}lower bound", "legendgroup": "LGB", "line": { "color": "rgb(0, 98, 98)", "width": 1 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 11362, 22724, 34087, 45449, 56812, 68174, 79536, 90899, 102261, 113624 ], "xaxis": "x", "y": [ 0.600363878017427, 0.6039270957785386, 0.6115148713417123, 0.6197340254006082, 0.6205710527541034, 0.6180208745748235, 0.6225260291844795, 0.6235612702090712, 0.6231646187316794, 0.6224778830911173 ], "yaxis": "y" } ], "layout": { "font": { "size": 12 }, "height": 600, "hoverlabel": { "font": { "size": 16 } }, "legend": { "bgcolor": "rgba(255, 255, 255, 0.5)", "font": { "size": 16 }, "groupclick": "togglegroup", "traceorder": "grouped", "x": 0.99, "xanchor": "right", "y": 0.01, "yanchor": "bottom" }, "margin": { "b": 50, "l": 50, "pad": 0, "r": 0, "t": 35 }, "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, 
"#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": 
"histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, 
"arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": 
"white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "font": { "size": 24 }, "pad": { "b": 15, "t": 15 }, "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "top" }, "width": 900, "xaxis": { "anchor": "y", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ 4903.27371014096, 120082.72628985904 ], "title": { "font": { "size": 16 }, "text": "Number of training samples" }, "type": "linear" }, "yaxis": { "anchor": "x", "automargin": true, "autorange": true, "domain": [ 0, 1 ], "range": [ 0.5987617013136991, 0.6308052353882555 ], "title": { "font": { "size": 16 }, "text": "Score" }, "type": "linear" } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the train sizing's results\n", "atom.plot_learning_curve()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }