{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Train sizing\n", "-----------------------\n", "\n", "This example shows how to assess a model's performance based on the size of the training set.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target column `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "UserWarning: The pandas version installed (1.5.3) does not match the supported pandas version in Modin (1.5.2). This may cause undesired side effects!\n" ] } ], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |

5 rows × 22 columns

| frac | model | score_train | score_test | time_fit | score_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| 0.1 | LGB01 | 0.7950 | 0.6169 | 2.701927 | 0.602473 | 2.366629 | 5.068556 |
| 0.2 | LGB02 | 0.7110 | 0.6172 | 3.361056 | 0.605984 | 2.923961 | 6.285017 |
| 0.3 | LGB03 | 0.6844 | 0.6205 | 4.114851 | 0.613633 | 3.573816 | 7.688667 |
| 0.4 | LGB04 | 0.6788 | 0.6246 | 4.704423 | 0.620894 | 4.312111 | 9.016534 |
| 0.5 | LGB05 | 0.6694 | 0.6256 | 5.332624 | 0.623075 | 4.956064 | 10.288688 |
| 0.6 | LGB06 | 0.6623 | 0.6270 | 6.176526 | 0.622287 | 5.432179 | 11.608705 |
| 0.7 | LGB07 | 0.6609 | 0.6307 | 6.786634 | 0.625412 | 6.138183 | 12.924817 |
| 0.8 | LGB08 | 0.6588 | 0.6316 | 7.660243 | 0.625519 | 7.141488 | 14.801731 |
| 0.9 | LGB09 | 0.6601 | 0.6318 | 8.433411 | 0.625334 | 7.352633 | 15.786044 |
| 1.0 | LGB10 | 0.6558 | 0.6310 | 8.937261 | 0.625840 | 8.158222 | 17.095483 |
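
One way to read the results table is to look at the gap between `score_train` and `score_test` for each training fraction. The snippet below is a minimal sketch in plain Python that does not use ATOM; the names `scores` and `generalization_gaps` are illustrative assumptions, and the numbers are copied by hand from the table above.

```python
# Scores copied from the results table: frac -> (score_train, score_test).
scores = {
    0.1: (0.7950, 0.6169),
    0.2: (0.7110, 0.6172),
    0.3: (0.6844, 0.6205),
    0.4: (0.6788, 0.6246),
    0.5: (0.6694, 0.6256),
    0.6: (0.6623, 0.6270),
    0.7: (0.6609, 0.6307),
    0.8: (0.6588, 0.6316),
    0.9: (0.6601, 0.6318),
    1.0: (0.6558, 0.6310),
}


def generalization_gaps(scores):
    """Return {frac: score_train - score_test}, rounded to 4 decimals."""
    return {
        frac: round(train - test, 4)
        for frac, (train, test) in scores.items()
    }


gaps = generalization_gaps(scores)

# Print one line per training fraction.
for frac, gap in gaps.items():
    print(f"frac={frac:.1f}  train-test gap={gap:.4f}")
```

With these values the gap falls from about 0.18 at 10% of the training set to about 0.025 at the full set, while `score_test` only moves from 0.6169 to 0.6310. Most of the change comes from the training score dropping as the model has less room to memorize a small sample.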