{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Subtitle theme classification\n",
    "\n",
    "Zero-shot classification (`facebook/bart-large-mnli`) of episode subtitle scripts: each script is split into sentences,\n",
    "grouped into batches of 20 sentences, scored against a list of theme labels, and the per-label scores are averaged over batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from glob import glob\n",
    "\n",
    "import nltk\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from nltk import sent_tokenize\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nltk.download('punkt')\n",
    "# Newer NLTK releases load the sent_tokenize model from 'punkt_tab' rather than 'punkt'.\n",
    "nltk.download('punkt_tab')\n",
    "nltk.download('stopwords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 0 if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"facebook/bart-large-mnli\"\n",
    "clf = pipeline(\"zero-shot-classification\",\n",
    "               model=model,\n",
    "               device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = \"I like your phone, does it even work?\"\n",
    "classes = [\"Love\", \"Appreciation\", \"Sarcasm\", \"Anger\", \"Hunger\", \"Dialogue\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf(test, classes, multi_label=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted so that subs[0] (and df.iloc[0] below) is the same file on every machine.\n",
    "subs = sorted(glob(\"../data/subs/*.srt\"))\n",
    "subs[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Understanding Data.\n",
    "with open(subs[0], \"r\", encoding=\"utf-8-sig\") as f:\n",
    "    con = f.read()\n",
    "print(con[:150])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop cue numbers, \"-->\" timestamp lines and blank lines; keep only the dialogue text.\n",
    "with open(subs[0], \"r\", encoding=\"utf-8-sig\") as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "cnt = 0\n",
    "con = []\n",
    "for line in lines:\n",
    "    line = line.strip()\n",
    "    if not line or line.isnumeric() or \"-->\" in line:\n",
    "        cnt += 1\n",
    "    else:\n",
    "        con.append(line)\n",
    "\n",
    "print(f\"Ignored {cnt} lines out of {len(lines)}. Total lines {len(con)} now.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Episode number: the last run of digits after the first \"-\" in the file name.\n",
    "print(subs[0])\n",
    "int(re.findall(r\"\\d+\", os.path.basename(subs[0]).split(\"-\")[1])[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_subs(pattern=\"../data/subs/*.srt\"):\n",
    "    \"\"\"Load every subtitle file matching `pattern` into a DataFrame.\n",
    "\n",
    "    Cue-number lines (all digits), timestamp lines (containing \"-->\") and\n",
    "    blank lines are dropped; the remaining text lines of each file are joined\n",
    "    with single spaces into one script. Files are read in sorted path order.\n",
    "\n",
    "    Returns a DataFrame with columns `episode` (int) and `script` (str).\n",
    "    Raises ValueError if a file name has no digits after its first \"-\".\n",
    "    \"\"\"\n",
    "    episodes = []\n",
    "    scripts = []\n",
    "\n",
    "    for sub in sorted(glob(pattern)):\n",
    "        # utf-8-sig strips a leading BOM, which would otherwise stick to the\n",
    "        # first cue number (\"\\ufeff1\") so that isnumeric() no longer drops it.\n",
    "        with open(sub, \"r\", encoding=\"utf-8-sig\") as f:\n",
    "            lines = [line.strip() for line in f]\n",
    "        con = [\n",
    "            line for line in lines\n",
    "            if line and not line.isnumeric() and \"-->\" not in line\n",
    "        ]\n",
    "\n",
    "        # Use the whole last number after the first \"-\" of the file name;\n",
    "        # keeping only the last character turned episode 12 into 2.\n",
    "        parts = os.path.basename(sub).split(\"-\")\n",
    "        digits = re.findall(r\"\\d+\", parts[1]) if len(parts) > 1 else []\n",
    "        if not digits:\n",
    "            raise ValueError(f\"Cannot parse episode number from {sub!r}\")\n",
    "        episodes.append(int(digits[-1]))\n",
    "        scripts.append(\" \".join(con))\n",
    "\n",
    "    df = pd.DataFrame({\"episode\": episodes, \"script\": scripts})\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = load_subs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "script = df.iloc[0][\"script\"]\n",
    "script[:500]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "script_sentences = sent_tokenize(script)\n",
    "script_sentences[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch sentences; join with a space so words at sentence boundaries don't run together.\n",
    "sentence_batch_size = 20\n",
    "script_batches = [\n",
    "    \" \".join(script_sentences[index:index + sentence_batch_size])\n",
    "    for index in range(0, len(script_sentences), sentence_batch_size)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(script_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theme_output = clf(\n",
    "    script_batches[:2],\n",
    "    classes,\n",
    "    multi_label=True\n",
    ")\n",
    "\n",
    "theme_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "themes = {}\n",
    "for output in theme_output:\n",
    "    for label, score in zip(output[\"labels\"], output[\"scores\"]):\n",
    "        themes.setdefault(label, []).append(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "themes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_theme_inference(script, classes=None, sentence_batch_size=20):\n",
    "    \"\"\"Return the mean zero-shot score of each theme label over `script`.\n",
    "\n",
    "    The script is split with `sent_tokenize`, grouped into batches of\n",
    "    `sentence_batch_size` sentences (joined with spaces), and every batch is\n",
    "    scored by the notebook-level `clf` pipeline with multi_label=True.\n",
    "\n",
    "    classes: theme labels to score; defaults to the nine labels below.\n",
    "    Returns a dict {label: mean score across batches}; an empty dict when\n",
    "    the script yields no sentences.\n",
    "    \"\"\"\n",
    "    if classes is None:\n",
    "        classes = [\"Sarcasm\", \"Happy\", \"Friendship\", \"Vulgar\", \"Anger\",\n",
    "                   \"Dialogue\", \"Sad\", \"Love\", \"Narration\"]\n",
    "\n",
    "    script_sentences = sent_tokenize(script)\n",
    "    script_batches = [\n",
    "        \" \".join(script_sentences[index:index + sentence_batch_size])\n",
    "        for index in range(0, len(script_sentences), sentence_batch_size)\n",
    "    ]\n",
    "    if not script_batches:\n",
    "        return {}\n",
    "\n",
    "    theme_output = clf(script_batches, classes, multi_label=True)\n",
    "    # Guard: some transformers versions return a bare dict for a single input.\n",
    "    if isinstance(theme_output, dict):\n",
    "        theme_output = [theme_output]\n",
    "\n",
    "    themes = {}\n",
    "    for output in theme_output:\n",
    "        for label, score in zip(output[\"labels\"], output[\"scores\"]):\n",
    "            themes.setdefault(label, []).append(score)\n",
    "\n",
    "    return {label: np.mean(scores) for label, scores in themes.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theme_scores = get_theme_inference(script[:500])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opdf = pd.Series(theme_scores)\n",
    "opdf"
   ]
  },
  {
"cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# assign() returns a new frame, so df is untouched (no SettingWithCopyWarning);\n",
    "# each theme label becomes a column holding that label's mean score.\n",
    "newdf = df.head(1).assign(**opdf.to_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "newdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}