{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# NewsClassifier" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "mtVYEQSYsswc", "outputId": "6f16c0c1-ef25-406c-dd14-edd1a72dc760", "trusted": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package stopwords to\n", "[nltk_data] C:\\Users\\manis\\AppData\\Roaming\\nltk_data...\n", "[nltk_data] Package stopwords is already up-to-date!\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Imports\n", "import os\n", "import gc\n", "import time\n", "from pathlib import Path\n", "import json\n", "from typing import Tuple, Dict\n", "from warnings import filterwarnings\n", "\n", "filterwarnings(\"ignore\")\n", "\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import ipywidgets as widgets\n", "from wordcloud import WordCloud, STOPWORDS\n", "\n", "from tqdm.auto import tqdm\n", "from dataclasses import dataclass\n", "\n", "import re\n", "import nltk\n", "from nltk.corpus import stopwords\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "from transformers import RobertaTokenizer, RobertaModel\n", "\n", "import wandb\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "nltk.download(\"stopwords\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "trusted": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmanishdrw1\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wandb.login()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "fGW_WYn31JHT", "trusted": true }, "outputs": [], "source": [ "@dataclass\n", "class Cfg:\n", " STOPWORDS = stopwords.words(\"english\")\n", " dataset_loc = \"../dataset/raw/news_dataset.csv\"\n", " test_size = 0.2\n", "\n", " add_special_tokens = True\n", " max_len = 50\n", " pad_to_max_length = True\n", " truncation = True\n", "\n", " change_config = False\n", "\n", " dropout_pb = 0.5\n", " lr = 1e-4\n", " lr_redfactor = 0.7\n", " lr_redpatience = 4\n", " epochs = 10\n", " batch_size = 128\n", "\n", " wandb_sweep = False" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "7V5OJWw4sswg", "outputId": "8eb13263-d31a-4d49-f1f6-3c2dc0595c78", "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matthew McConaughey Gives Joy Behar A Foot Massage On ‘The View’\n", "Entertainment\n" ] } ], "source": [ "df = pd.read_csv(Cfg.dataset_loc)\n", "print(df[\"Title\"][10040])\n", "print(df[\"Category\"][10040])" ] }, { "cell_type": "markdown", "metadata": { "id": "w05pkO5RN1H2" }, "source": [ "## Prepare Data" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "l8Z3Hhk3sswg", "trusted": true }, "outputs": [], "source": [ "def prepare_data(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:\n", " \"\"\"Separate headlines instance and feature selection.\n", "\n", " Args:\n", " df: original dataframe.\n", "\n", " Returns:\n", " df: new dataframe with appropriate features.\n", " headlines_df: dataframe cintaining \"headlines\" category instances.\n", " \"\"\"\n", " df = df[[\"Title\", \"Category\"]]\n", " df.rename(columns={\"Title\": \"Text\"}, inplace=True)\n", " df, headlines_df = df[df[\"Category\"] != \"Headlines\"].reset_index(drop=True), df[df[\"Category\"] == \"Headlines\"].reset_index(drop=True)\n", "\n", " return df, headlines_df" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "d4t7JjIEsswg", "trusted": true }, "outputs": [], "source": [ "def clean_text(text: str) -> str:\n", " \"\"\"Clean text (lower, puntuations removal, blank space removal).\"\"\"\n", " # lower case the text\n", " text = text.lower() # necessary to do before as stopwords are in lower case\n", "\n", " # remove stopwords\n", " stp_pattern = re.compile(r\"\\b(\" + r\"|\".join(Cfg.STOPWORDS) + r\")\\b\\s*\")\n", " text = stp_pattern.sub(\"\", text)\n", "\n", " # custom cleaning\n", " text = text.strip() # remove space at start or end if any\n", " text = re.sub(\" +\", \" \", text) # remove extra spaces\n", " text = re.sub(\"[^A-Za-z0-9]+\", \" \", text) # remove characters that are not alphanumeric\n", "\n", " return text" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "NokmvVFusswh", "trusted": true }, "outputs": [], "source": [ "def preprocess(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict, Dict]:\n", " \"\"\"Preprocess the data.\n", "\n", " Args:\n", " df: Dataframe on which the preprocessing steps need to be performed.\n", "\n", " Returns:\n", " df: Preprocessed Data.\n", " class_to_index: class labels to indices mapping\n", " class_to_index: indices to class labels mapping\n", " \"\"\"\n", " df, headlines_df = prepare_data(df)\n", "\n", " cats = df[\"Category\"].unique().tolist()\n", " num_classes = len(cats)\n", " class_to_index = {tag: i for i, tag in enumerate(cats)}\n", " index_to_class = {v: k for k, v in class_to_index.items()}\n", "\n", " df[\"Text\"] = df[\"Text\"].apply(clean_text) # clean text\n", " df = df[[\"Text\", \"Category\"]]\n", " df[\"Category\"] = df[\"Category\"].map(class_to_index) # label encoding\n", " return df, class_to_index, index_to_class" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "f45cNikCsswh", "outputId": "880e338e-11a3-4048-ccf7-d30bf13e996b", "trusted": true }, "outputs": [ { "data": { "text/html": [ "
\n", " | Text | \n", "Category | \n", "
---|---|---|
0 | \n", "chainlink link falters hedera hbar wobbles yet... | \n", "0 | \n", "
1 | \n", "funds punished owning nvidia shares stunning 2... | \n", "0 | \n", "
2 | \n", "crude oil prices stalled hedge funds sold kemp | \n", "0 | \n", "
3 | \n", "grayscale bitcoin win still half battle | \n", "0 | \n", "
4 | \n", "home shopping editor miss labor day deals eyeing | \n", "0 | \n", "
... | \n", "... | \n", "... | \n", "
44142 | \n", "slovakia election could echo ukraine expect | \n", "6 | \n", "
44143 | \n", "things know nobel prizes washington post | \n", "6 | \n", "
44144 | \n", "brief calm protests killing 2 students rock im... | \n", "6 | \n", "
44145 | \n", "one safe france vows action bedbugs sweep paris | \n", "6 | \n", "
44146 | \n", "slovakia election polls open knife edge vote u... | \n", "6 | \n", "
44147 rows × 2 columns
\n", "