{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Train sizing\n", "-----------------------\n", "\n", "This example shows how to assess a model's performance as a function of the size of the training set.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal is to predict whether it will rain tomorrow by training a binary classifier on the target column `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
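Train sizing works like a learning curve. The same model is fitted on increasing fractions of the training set, and every fit is scored on the same held-out test set. The sketch below shows that loop in plain Python under several assumptions. It uses synthetic two-feature data in place of the weather dataset and a hypothetical nearest-centroid classifier (`fit_centroids`) as a stand-in for LightGBM. Balanced accuracy serves as the metric. All helper names are illustrative. This is not ATOM's implementation, and it makes no ATOM calls.

```python
import random


def make_data(n, seed=0):
    """Synthetic, imbalanced binary data (hypothetical stand-in for the weather set)."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        label = 1 if rng.random() < 0.3 else 0  # roughly 30% positives
        centre = 1.0 if label else -1.0
        features = (rng.gauss(centre, 1.5), rng.gauss(centre, 1.5))
        data.append((features, label))
    return data


def fit_centroids(train):
    """Hypothetical nearest-centroid model: the mean feature vector per class."""
    sums, counts = {}, {}
    for (x1, x2), y in train:
        s = sums.setdefault(y, [0.0, 0.0])
        s[0] += x1
        s[1] += x2
        counts[y] = counts.get(y, 0) + 1
    return {y: (s[0] / counts[y], s[1] / counts[y]) for y, s in sums.items()}


def predict(centroids, x):
    """Assign x to the class whose centroid is closest (squared Euclidean distance)."""
    return min(
        centroids,
        key=lambda y: (x[0] - centroids[y][0]) ** 2 + (x[1] - centroids[y][1]) ** 2,
    )


def balanced_accuracy(centroids, data):
    """Mean per-class recall over the classes present in data."""
    correct, total = {}, {}
    for x, y in data:
        total[y] = total.get(y, 0) + 1
        correct[y] = correct.get(y, 0) + (predict(centroids, x) == y)
    recalls = [correct[y] / total[y] for y in total]
    return sum(recalls) / len(recalls)


def train_sizing(train, test, fracs, seed=0):
    """Fit one model per fraction of the training set; score each on the same test set."""
    shuffled = list(train)
    random.Random(seed).shuffle(shuffled)
    results = []
    for frac in fracs:
        n = max(1, round(frac * len(shuffled)))
        subset = shuffled[:n]  # nested subsets: each is a prefix of the same shuffle
        model = fit_centroids(subset)
        results.append((frac, balanced_accuracy(model, subset), balanced_accuracy(model, test)))
    return results


data = make_data(1000, seed=42)
train, test = data[:800], data[800:]
fracs = [round(0.1 * i, 1) for i in range(1, 11)]
results = train_sizing(train, test, fracs)
for frac, score_train, score_test in results:
    print(f"frac={frac:.1f}  score_train={score_train:.4f}  score_test={score_test:.4f}")
```

In this sketch, each smaller subset is a prefix of one shuffled training set, so the subsets are nested. As a result, differences between fractions reflect the amount of data rather than unrelated random draws. Whether ATOM samples its subsets the same way is not stated here.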
| frac | model | score_train | score_test | time_fit | score_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| 0.1 | LGB01 | 0.7950 | 0.6169 | 2.726477 | 0.602473 | 2.361145 | 5.087622 |
| 0.2 | LGB02 | 0.7110 | 0.6172 | 3.587786 | 0.605984 | 3.214102 | 6.801888 |
| 0.3 | LGB03 | 0.6844 | 0.6205 | 4.144765 | 0.613633 | 3.724748 | 7.869513 |
| 0.4 | LGB04 | 0.6788 | 0.6246 | 4.740403 | 0.620894 | 4.360960 | 9.101363 |
| 0.5 | LGB05 | 0.6694 | 0.6256 | 5.559976 | 0.623075 | 5.128658 | 10.688634 |
| 0.6 | LGB06 | 0.6623 | 0.6270 | 6.234684 | 0.622287 | 5.758230 | 11.992914 |
| 0.7 | LGB07 | 0.6609 | 0.6307 | 6.979477 | 0.625412 | 6.485406 | 13.464883 |
| 0.8 | LGB08 | 0.6588 | 0.6316 | 7.868586 | 0.625519 | 7.226822 | 15.095408 |
| 0.9 | LGB09 | 0.6601 | 0.6318 | 8.578300 | 0.625334 | 8.168814 | 16.747114 |
| 1.0 | LGB10 | 0.6558 | 0.6310 | 9.401000 | 0.625840 | 8.782370 | 18.183370 |
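The train-sizing results can also be read numerically. The sketch below copies the table's values into plain Python, with no ATOM calls. For each fraction, it computes the gap between `score_train` and `score_test`, then picks the fraction with the highest `score_test`. It also checks whether, in this run, the `time` column equals `time_fit + time_bootstrap`. Treating the gap as an overfitting signal is an interpretive assumption, not something the results report directly.

```python
# Values copied from the train-sizing results table:
# (frac, score_train, score_test, time_fit, time_bootstrap, time)
rows = [
    (0.1, 0.7950, 0.6169, 2.726477, 2.361145, 5.087622),
    (0.2, 0.7110, 0.6172, 3.587786, 3.214102, 6.801888),
    (0.3, 0.6844, 0.6205, 4.144765, 3.724748, 7.869513),
    (0.4, 0.6788, 0.6246, 4.740403, 4.360960, 9.101363),
    (0.5, 0.6694, 0.6256, 5.559976, 5.128658, 10.688634),
    (0.6, 0.6623, 0.6270, 6.234684, 5.758230, 11.992914),
    (0.7, 0.6609, 0.6307, 6.979477, 6.485406, 13.464883),
    (0.8, 0.6588, 0.6316, 7.868586, 7.226822, 15.095408),
    (0.9, 0.6601, 0.6318, 8.578300, 8.168814, 16.747114),
    (1.0, 0.6558, 0.6310, 9.401000, 8.782370, 18.183370),
]

# Gap between training and test score per fraction (a rough overfitting signal).
gaps = {frac: round(s_train - s_test, 4) for frac, s_train, s_test, *_ in rows}

# Fraction with the highest test score.
best_frac, _, best_test, *_ = max(rows, key=lambda row: row[2])

# Check whether the total time is the sum of the fit and bootstrap times.
time_is_sum = all(abs(fit + boot - total) < 1e-6 for *_, fit, boot, total in rows)

for frac, gap in gaps.items():
    print(f"frac={frac:.1f}  gap={gap:.4f}")
print(f"best score_test={best_test:.4f} at frac={best_frac}")
print(f"time == time_fit + time_bootstrap: {time_is_sum}")
```

On these numbers, the gap falls from 0.1781 at 10% of the training set to 0.0248 at 100%. Over the same range, `score_test` climbs only from 0.6169 to a peak of 0.6318 at 90%. The shrinking gap therefore comes mostly from the falling training score, which hints that more data of the same kind may improve the test score only marginally.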