{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Train sizing\n", "-----------------------\n", "\n", "This example shows how to assess a model's performance based on the size of the training set.\n", "\n", "The data used is a variation on the [Australian weather dataset](https://www.kaggle.com/jsphyg/weather-dataset-rattle-package) from Kaggle. You can download it from [here](https://github.com/tvdboom/ATOM/blob/master/examples/datasets/weatherAUS.csv). The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target `RainTomorrow`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import packages\n", "import pandas as pd\n", "from atom import ATOMClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |

5 rows × 22 columns

| frac | model | f1_train | f1_test | time_fit | f1_bootstrap | time_bootstrap | time |
|---|---|---|---|---|---|---|---|
| 0.1 | LR01 | 0.5621 | 0.5848 | 1.200989 | 0.584922 | 0.909424 | 2.110413 |
| 0.2 | LR02 | 0.5832 | 0.5846 | 1.238803 | 0.585234 | 1.076237 | 2.315040 |
| 0.3 | LR03 | 0.5800 | 0.5852 | 1.426356 | 0.586118 | 1.456787 | 2.883143 |
| 0.4 | LR04 | 0.5845 | 0.5857 | 1.501177 | 0.586348 | 1.585152 | 3.086329 |
| 0.5 | LR05 | 0.5833 | 0.5865 | 1.672944 | 0.585384 | 1.635141 | 3.308085 |
| 0.6 | LR06 | 0.5831 | 0.5832 | 1.865011 | 0.584891 | 2.217760 | 4.082771 |
| 0.7 | LR07 | 0.5878 | 0.5858 | 2.301655 | 0.585235 | 2.969447 | 5.271102 |
| 0.8 | LR08 | 0.5916 | 0.5886 | 5.181574 | 0.585269 | 4.662252 | 9.843826 |
| 0.9 | LR09 | 0.5856 | 0.5833 | 7.434335 | 0.584633 | 7.262542 | 14.696877 |
| 1.0 | LR10 | 0.5858 | 0.5848 | 7.974012 | 0.584836 | 6.985089 | 14.959101 |
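
One way to read the results table is to compare how much the test score moves across training fractions with how much the total time grows. The sketch below is not part of the original notebook and does not use the ATOM API: it copies the `f1_test` and `time` columns from the table by hand into plain Python lists, and the variable names (`fracs`, `spread`, `slowdown`, `best_frac`) are illustrative only.

```python
# Values copied by hand from the train-sizing results table (illustrative only).
fracs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
f1_test = [0.5848, 0.5846, 0.5852, 0.5857, 0.5865,
           0.5832, 0.5858, 0.5886, 0.5833, 0.5848]
total_time = [2.110413, 2.315040, 2.883143, 3.086329, 3.308085,
              4.082771, 5.271102, 9.843826, 14.696877, 14.959101]

# Range of the test score over all training fractions.
spread = max(f1_test) - min(f1_test)

# How much longer the full training set takes compared with 10% of it.
slowdown = total_time[-1] / total_time[0]

# Fraction at which the highest test score was observed.
best_frac = fracs[f1_test.index(max(f1_test))]

print(f"f1_test spread across fractions: {spread:.4f}")
print(f"time at frac=1.0 relative to frac=0.1: {slowdown:.1f}x")
print(f"fraction with the highest f1_test: {best_frac}")
```

For this particular run, the test score stays within a narrow band while the total time grows several times over, which is the kind of trade-off a train-sizing experiment is meant to expose.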