{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e0102cb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# transformers.AdamW is deprecated in favour of PyTorch's own implementation.\n",
    "from torch.optim import AdamW\n",
    "import pandas as pd\n",
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler\n",
    "\n",
    "pl.seed_everything(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ec5ec2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME='t5-base'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8044c622",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "INPUT_MAX_LEN = 128  # max tokens fed to the encoder\n",
    "OUTPUT_MAX_LEN = 128  # max tokens in a generated reply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6390f2de",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8eec35d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class T5Model(pl.LightningModule):\n",
    "    \"\"\"LightningModule wrapping a pretrained T5ForConditionalGeneration.\n",
    "\n",
    "    Training/validation batches are read as mappings with \"input_ids\",\n",
    "    \"attention_mask\" and \"target\" (passed to the model as labels) entries.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        \"\"\"Run the wrapped T5 model and return its ``(loss, logits)`` pair.\"\"\"\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"target\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "\n",
    "        self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n",
    "\n",
    "        return {'loss': loss}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"target\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "\n",
    "        self.log(\"val_loss\", loss, prog_bar=True, logger=True)\n",
    "\n",
    "        return {'val_loss': loss}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # eps / weight_decay reproduce the defaults of the former transformers.AdamW;\n",
    "        # torch.optim.AdamW's own defaults differ (eps=1e-8, weight_decay=1e-2).\n",
    "        return AdamW(self.parameters(), lr=0.0001, eps=1e-6, weight_decay=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "e9d96844",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_model = T5Model.load_from_checkpoint('best-model.ckpt', map_location=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "3449943f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_model.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "0e9f1058",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_question(question):\n",
    "    \"\"\"Return the chatbot's reply to `question`, generated by the fine-tuned T5 model.\n",
    "\n",
    "    Despite the name, the returned string is the model's answer: `question` is\n",
    "    tokenized (padded/truncated to INPUT_MAX_LEN), decoded with 4-beam search,\n",
    "    and the decoded sequences (special tokens stripped) are joined into one string.\n",
    "    \"\"\"\n",
    "    inputs_encoding = tokenizer(\n",
    "        question,\n",
    "        add_special_tokens=True,\n",
    "        max_length=INPUT_MAX_LEN,\n",
    "        padding='max_length',\n",
    "        truncation='only_first',\n",
    "        return_attention_mask=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "\n",
    "    # The tokenizer returns CPU tensors, while the model was loaded with\n",
    "    # map_location=DEVICE; move the inputs so generate() also works on CUDA.\n",
    "    model_device = train_model.device\n",
    "\n",
    "    generate_ids = train_model.model.generate(\n",
    "        input_ids=inputs_encoding[\"input_ids\"].to(model_device),\n",
    "        attention_mask=inputs_encoding[\"attention_mask\"].to(model_device),\n",
    "        # Bounds the length of the generated reply, hence the output budget.\n",
    "        max_length=OUTPUT_MAX_LEN,\n",
    "        num_beams=4,\n",
    "        num_return_sequences=1,\n",
    "        no_repeat_ngram_size=2,\n",
    "        early_stopping=True,\n",
    "    )\n",
    "\n",
    "    preds = [\n",
    "        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)\n",
    "        for gen_id in generate_ids\n",
    "    ]\n",
    "\n",
    "    # num_return_sequences=1, so this is normally a single decoded reply.\n",
    "    return \"\".join(preds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "ee38a88c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ques:  who is elon musk?\n",
      "BOT:  He's a shitlord.\n"
     ]
    }
   ],
   "source": [
    "ques = \"who is elon musk?\"\n",
    "print(\"Ques: \",ques)\n",
    "print(\"BOT: \",generate_question(ques))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "22aa4414",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7904\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "
" ], "text/plain": [ "