{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Train sizing\n", "-----------------------\n", "\n", "This example shows how to assess a model's performance based on the size of the training set.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal is to predict whether or not it will rain tomorrow by training a binary classifier on the target column `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
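Train sizing boils down to one loop: fit the same model on growing fractions of the training set and score every fit on one fixed test set. The sketch below illustrates that loop in plain Python under stated assumptions. It is not ATOM's implementation. A hypothetical synthetic dataset (`make_data`) stands in for the weather data, and a nearest-centroid classifier stands in for the logistic regression models (`LR01` to `LR10`) in the results table. The bootstrap columns of that table are not reproduced.

```python
import random


def make_data(n, seed=0):
    """Hypothetical synthetic data: two features whose mean shifts by 1.0
    when it rains tomorrow (label 1). Rows are i.i.d., so any prefix is a
    random subset."""
    rng = random.Random(seed)
    X, y = [], []
    for _ in range(n):
        label = 1 if rng.random() < 0.25 else 0
        shift = 1.0 if label else 0.0
        X.append([rng.gauss(shift, 1.0), rng.gauss(shift, 1.0)])
        y.append(label)
    return X, y


def f1_score(y_true, y_pred):
    """F1 score of the positive class (label 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def fit_nearest_centroid(X, y):
    """Stand-in model: the mean feature vector of each class."""
    centroids = {}
    for label in (0, 1):
        rows = [x for x, t in zip(X, y) if t == label]
        if not rows:
            raise ValueError(f"no samples of class {label} in this subset")
        centroids[label] = [sum(col) / len(rows) for col in zip(*rows)]
    return centroids


def predict(centroids, X):
    """Assign each row to the class with the closest centroid (squared distance)."""
    def dist(x, c):
        return sum((a - b) ** 2 for a, b in zip(x, c))
    return [min(centroids, key=lambda label: dist(x, centroids[label])) for x in X]


def train_sizing(X_train, y_train, X_test, y_test, fractions):
    """Fit on the first `frac` share of the training set, score on the fixed test set."""
    results = []
    for frac in fractions:
        n = max(1, round(frac * len(X_train)))
        model = fit_nearest_centroid(X_train[:n], y_train[:n])
        score = f1_score(y_test, predict(model, X_test))
        results.append((frac, score))
    return results


X, y = make_data(1000)
X_train, y_train, X_test, y_test = X[:800], y[:800], X[800:], y[800:]
fractions = [i / 10 for i in range(1, 11)]
results = train_sizing(X_train, y_train, X_test, y_test, fractions)
for frac, score in results:
    print(f"frac={frac:.1f}  f1_test={score:.4f}")
```

Because the test set never changes, differences between rows come only from how much training data each fit saw. This mirrors how the `frac` and `f1_test` columns should be read.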
| frac | model | f1_train | f1_test | time_fit | f1_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| 0.1 | LR01 | 0.562100 | 0.584800 | 1.181076 | 0.584922 | 0.909830 | 2.090906 |
| 0.2 | LR02 | 0.583200 | 0.584600 | 1.455324 | 0.585234 | 1.120021 | 2.575345 |
| 0.3 | LR03 | 0.580000 | 0.585200 | 1.702020 | 0.586118 | 1.354517 | 3.056537 |
| 0.4 | LR04 | 0.584500 | 0.585700 | 2.250048 | 0.586348 | 1.599457 | 3.849505 |
| 0.5 | LR05 | 0.583300 | 0.586500 | 2.163214 | 0.585384 | 1.877947 | 4.041161 |
| 0.6 | LR06 | 0.583100 | 0.583200 | 2.338079 | 0.584891 | 1.898731 | 4.236810 |
| 0.7 | LR07 | 0.587800 | 0.585800 | 2.426779 | 0.585235 | 2.059590 | 4.486369 |
| 0.8 | LR08 | 0.591600 | 0.588600 | 2.630608 | 0.585269 | 2.172981 | 4.803589 |
| 0.9 | LR09 | 0.585600 | 0.583300 | 2.836993 | 0.584633 | 2.550147 | 5.387140 |
| 1.0 | LR10 | 0.585800 | 0.584800 | 4.211031 | 0.584836 | 2.966612 | 7.177643 |
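The `f1_test` and `f1_bootstrap` columns hardly move as the training fraction grows, while `time` keeps increasing. The short plain-Python check below quantifies this. Its values are copied from the results table above, and no other data is assumed.

```python
# (frac, f1_test, f1_bootstrap, time) copied from the results table above.
rows = [
    (0.1, 0.5848, 0.584922, 2.090906),
    (0.2, 0.5846, 0.585234, 2.575345),
    (0.3, 0.5852, 0.586118, 3.056537),
    (0.4, 0.5857, 0.586348, 3.849505),
    (0.5, 0.5865, 0.585384, 4.041161),
    (0.6, 0.5832, 0.584891, 4.236810),
    (0.7, 0.5858, 0.585235, 4.486369),
    (0.8, 0.5886, 0.585269, 4.803589),
    (0.9, 0.5833, 0.584633, 5.387140),
    (1.0, 0.5848, 0.584836, 7.177643),
]

# Range of the test score over all training fractions.
test_scores = [r[1] for r in rows]
spread = max(test_scores) - min(test_scores)

# Fraction with the highest bootstrap score.
best_boot = max(rows, key=lambda r: r[2])

# How much longer the full run takes than the 10% run.
time_ratio = rows[-1][3] / rows[0][3]

print(f"f1_test spread across fractions: {spread:.4f}")
print(f"best f1_bootstrap: {best_boot[2]:.6f} at frac={best_boot[0]}")
print(f"time grows {time_ratio:.1f}x from 10% to 100% of the training set")
```

On these numbers, the test F1 varies by only about 0.005 across all fractions, while the total time grows roughly 3.4 times. This suggests that, for this model, a fraction of the training set may already be enough.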