{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: NLP\n", "--------------\n", "\n", "This example shows how to use ATOM to quickly go from raw text data to model predictions.\n", "\n", "Import the 20 newsgroups text dataset from [sklearn.datasets](https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html). The dataset comprises around 18000 articles on 20 topics. The goal is to predict the topic of every article." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from atom import ATOMClassifier\n", "from sklearn.datasets import fetch_20newsgroups" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Use only a subset of the available topics for faster processing\n", "X_text, y_text = fetch_20newsgroups(\n", " return_X_y=True,\n", " categories=[\n", " 'sci.med',\n", " 'comp.windows.x',\n", " 'misc.forsale',\n", " 'rec.autos',\n", " ],\n", " shuffle=True,\n", " random_state=1,\n", ")\n", "X_text = np.array(X_text).reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<< ================== ATOM ================== >>\n", "Algorithm task: multiclass classification.\n", "\n", "Dataset stats ==================== >>\n", "Shape: (2366, 2)\n", "Memory: 4.07 MB\n", "Scaled: False\n", "Categorical features: 1 (100.0%)\n", "-------------------------------------\n", "Train set size: 1657\n", "Test set size: 709\n", "-------------------------------------\n", "| | dataset | train | test |\n", "| - | ----------- | ----------- | ----------- |\n", "| 0 | 593 (1.0) | 415 (1.0) | 178 (1.0) |\n", "| 1 | 585 (1.0) | 410 (1.0) | 175 (1.0) |\n", "| 2 | 594 (1.0) | 416 (1.0) | 178 (1.0) |\n", "| 
3 | 594 (1.0) | 416 (1.0) | 178 (1.0) |\n", "\n" ] } ], "source": [ "atom = ATOMClassifier(X_text, y_text, index=True, test_size=0.3, verbose=2, random_state=1)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"|      | corpus | target |\n",
"| ---- | ------ | ------ |\n",
"| 1731 | From: rlm@helen.surfcty.com (Robert L. McMilli... | 0 |\n",
"| 1496 | From: carl@SOL1.GPS.CALTECH.EDU (Carl J Lydick... | 3 |\n",
"| 1290 | From: thssjxy@iitmax.iit.edu (Smile)\\nSubject:... | 1 |\n",
"| 2021 | From: c23st@kocrsv01.delcoelect.com (Spiros Tr... | 2 |\n",
"| 142 | From: ginkgo@ecsvax.uncecs.edu (J. Geary Morto... | 1 |\n",
"| ... | ... | ... |\n",
"| 510 | From: mary@uicsl.csl.uiuc.edu (Mary E. Allison... | 3 |\n",
"| 1948 | From: ndd@sunbar.mc.duke.edu (Ned Danieley)\\nS... | 0 |\n",
"| 798 | From: kk@unisql.UUCP (Kerry Kimbrough)\\nSubjec... | 0 |\n",
"| 2222 | From: hamachi@adobe.com (Gordon Hamachi)\\nSubj... | 2 |\n",
"| 2215 | From: mobasser@vu-vlsi.ee.vill.edu (Bijan Moba... | 2 |\n",
"\n",
"2366 rows × 2 columns
\n",
"\n",
"|      | email | url | html | emoji | number |\n",
"| ---- | ----- | --- | ---- | ----- | ------ |\n",
"| 1731 | [rlm@helen.surfcty.com, rlm@helen.surfcty.com] | NaN | [<std.disclaimer.h>] | NaN | [8] |\n",
"| 1496 | [carl@sol1.gps.caltech.edu, carl@sol1.gps.calt... | NaN | [<>] | NaN | [28] |\n",
"| 1290 | [thssjxy@iitmax.iit.edu, thssjxy@iitmax.acc.ii... | NaN | NaN | NaN | [223158, 15645, 14, 80, 150] |\n",
"| 2021 | [c23st@kocrsv01.delcoelect.com, c4wjgq.a40@con... | NaN | [<>] | NaN | [10, 21, 6, 317, 451, 0815, 46904] |\n",
"| 142 | [ginkgo@ecsvax.uncecs.edu, ginkgo@uncecs.edu] | NaN | [<>] | NaN | [95, 17, 95, 95, 95, 100, 00, 919, 851, 6565, ... |\n",
"| ... | ... | ... | ... | ... | ... |\n",
"| 403 | NaN | NaN | NaN | NaN | [223, 250, 10, 8, 8, 2002, 1600] |\n",
"| 1634 | NaN | NaN | NaN | NaN | [15, 1, 1] |\n",
"| 1262 | NaN | NaN | NaN | NaN | [38, 84] |\n",
"| 1360 | NaN | NaN | NaN | NaN | [27, 15, 27, 225, 250, 412, 624, 6115, 371, 0154] |\n",
"| 211 | NaN | NaN | NaN | NaN | [13, 93, 212, 274, 0646, 1097, 08836, 908, 563... |\n",
"\n",
"2366 rows × 5 columns
\n", "