{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "!pip install transformers" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6_gaeY1UMPOv", "outputId": "470ea044-c9b1-400e-f322-aafbdbae4aea" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.31.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n" ] } ] }, { "cell_type": "code", "source": [ "!pip install peft" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UkDCPUBOMh-L", "outputId": "0c618ade-6b5b-4500-8063-a51c29880fb4" }, "execution_count": 13, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.4.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.22.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.0.1+cu118)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.31.0)\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft) (0.21.0)\n", "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.3.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.12.2)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.7.1)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.11.1)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.0.0)\n", "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (3.25.2)\n", "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (16.0.6)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.16.4)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2022.10.31)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.27.1)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.13.3)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (4.65.0)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers->peft) (2023.6.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.12)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n" ] } ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "6YOhmSaCMK2M" }, "outputs": [], "source": [ "# from transformers import AutoModelForSeq2SeqLM\n", "# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType\n", "# import torch\n", "# model_name_or_path = \"microsoft/GODEL-v1_1-large-seq2seq\"\n", "# tokenizer_name_or_path = \"microsoft/GODEL-v1_1-large-seq2seq\"" ] }, { "cell_type": "code", "source": [ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "\n", "# Replace 'microsoft/GODEL-v1_1-large-seq2seq' with the model name\n", "model_name = 'microsoft/GODEL-v1_1-large-seq2seq'\n", "\n", "# Load the model and tokenizer\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ], "metadata": { "id": "r1zRNhfYXN8T" }, "execution_count": 2, "outputs": [] }, { "cell_type": "code", "source": [ "# Output directory\n", "output_dir = \"medbot_model\"\n", "\n", "# Save the model and tokenizer using the standard Hugging Face naming convention\n", "model.save_pretrained(output_dir)\n", "tokenizer.save_pretrained(output_dir)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UjV85bPQXw7P", "outputId": "688d07cb-eddd-4a6a-819e-57efd837324b" }, "execution_count": 15, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('medbot_model/tokenizer_config.json',\n", " 'medbot_model/special_tokens_map.json',\n", " 'medbot_model/tokenizer.json')" ] }, "metadata": {}, "execution_count": 15 } ] }, { "cell_type": "code", "source": [ "# # peft config\n", "\n", "# peft_config = LoraConfig(\n", "# task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=6, lora_alpha=16, lora_dropout=0.2\n", "# )" ], "metadata": { "id": "qmIGSnctujOh" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [ "# model = get_peft_model(model, peft_config)\n", "# model.print_trainable_parameters()\n", "\n", "# output_dir = \"medbot_model_peft\"\n", "\n", "# model.save_pretrained(output_dir)\n", "# tokenizer.save_pretrained(output_dir)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RulB42QiMOhi", "outputId": "e8e2d65d-8afa-4095-bf8b-93749e39b785" }, "execution_count": 14, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "trainable params: 1,769,472 || all params: 739,410,944 || trainable%: 0.23930833244469804\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "('medbot_model_peft/tokenizer_config.json',\n", " 'medbot_model_peft/special_tokens_map.json',\n", " 'medbot_model_peft/tokenizer.json')" ] }, "metadata": {}, "execution_count": 14 } ] }, { "cell_type": "code", "source": [ "# ============================== Load Dataset ==========================" ], "metadata": { "id": "Xj4K4WU-NYp8" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "import pandas as pd\n", "\n", "df = pd.read_csv('/content/drive/MyDrive/Dataset/diseaseDataSetFull2.csv')\n", "df" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 607 }, "id": "Wc8bxpybNgFo", "outputId": "e5e960a5-e460-4b27-ea2f-bfb86dbbb06b" }, "execution_count": 3, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " disease \\\n", "0 Fungal infection \n", "1 Fungal infection \n", "2 Fungal infection \n", "3 Fungal infection \n", "4 Fungal infection \n", "... ... \n", "4915 (vertigo) Paroymsal Positional Vertigo \n", "4916 Acne \n", "4917 Urinary tract infection \n", "4918 Psoriasis \n", "4919 Impetigo \n", "\n", " symptoms \\\n", "0 itching,skin_rash,nodal_skin_eruptions,dischro... \n", "1 skin_rash,nodal_skin_eruptions,dischromic__pat... \n", "2 itching,nodal_skin_eruptions,dischromic__patches \n", "3 itching,skin_rash,dischromic__patches \n", "4 itching,skin_rash,nodal_skin_eruptions \n", "... ... \n", "4915 vomiting,headache,nausea,spinning_movements,lo... \n", "4916 skin_rash,pus_filled_pimples,blackheads,scurring \n", "4917 burning_micturition,bladder_discomfort,foul_sm... \n", "4918 skin_rash,joint_pain,skin_peeling,silver_like_... \n", "4919 skin_rash,high_fever,blister,red_sore_around_n... \n", "\n", " precautions \n", "0 bath twice, use detol or neem in bathing water... \n", "1 bath twice, use detol or neem in bathing water... \n", "2 bath twice, use detol or neem in bathing water... \n", "3 bath twice, use detol or neem in bathing water... \n", "4 bath twice, use detol or neem in bathing water... \n", "... ... \n", "4915 lie down, avoid sudden change in body, avoid a... \n", "4916 bath twice, avoid fatty spicy food, drink plen... \n", "4917 drink plenty of water, increase vitamin c inta... \n", "4918 wash hands with warm soapy water, stop bleedin... \n", "4919 soak affected area in warm water, use antibiot... \n", "\n", "[4920 rows x 3 columns]" ], "text/html": [ "\n", "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
diseasesymptomsprecautions
0Fungal infectionitching,skin_rash,nodal_skin_eruptions,dischro...bath twice, use detol or neem in bathing water...
1Fungal infectionskin_rash,nodal_skin_eruptions,dischromic__pat...bath twice, use detol or neem in bathing water...
2Fungal infectionitching,nodal_skin_eruptions,dischromic__patchesbath twice, use detol or neem in bathing water...
3Fungal infectionitching,skin_rash,dischromic__patchesbath twice, use detol or neem in bathing water...
4Fungal infectionitching,skin_rash,nodal_skin_eruptionsbath twice, use detol or neem in bathing water...
............
4915(vertigo) Paroymsal Positional Vertigovomiting,headache,nausea,spinning_movements,lo...lie down, avoid sudden change in body, avoid a...
4916Acneskin_rash,pus_filled_pimples,blackheads,scurringbath twice, avoid fatty spicy food, drink plen...
4917Urinary tract infectionburning_micturition,bladder_discomfort,foul_sm...drink plenty of water, increase vitamin c inta...
4918Psoriasisskin_rash,joint_pain,skin_peeling,silver_like_...wash hands with warm soapy water, stop bleedin...
4919Impetigoskin_rash,high_fever,blister,red_sore_around_n...soak affected area in warm water, use antibiot...
\n", "

4920 rows × 3 columns

\n", "
\n", " \n", "\n", "\n", "\n", "
\n", " \n", "
\n", "\n", "\n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "
\n", "
\n" ] }, "metadata": {}, "execution_count": 3 } ] }, { "cell_type": "code", "source": [ "def dataframe_to_dataset(df):\n", " \"\"\"\n", " Convert a DataFrame with columns 'disease', 'symptoms', and 'precautions'\n", " into a list of tuples dataset.\n", "\n", " Parameters:\n", " df (pd.DataFrame): Input DataFrame with columns 'disease', 'symptoms', and 'precautions'.\n", "\n", " Returns:\n", " list: A list of tuples, where each tuple contains information about a specific disease,\n", " symptoms, and precautions.\n", " \"\"\"\n", " if not all(col in df.columns for col in ['disease', 'symptoms', 'precautions']):\n", " raise ValueError(\"DataFrame must contain 'disease', 'symptoms', and 'precautions' columns.\")\n", "\n", " dataset = []\n", " for _, row in df.iterrows():\n", " disease = row['disease']\n", " symptoms = row['symptoms']\n", " precautions = row['precautions']\n", " dataset.append((disease, symptoms, precautions))\n", "\n", " return dataset\n", "\n", "data = dataframe_to_dataset(df)\n", "data[:10]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rGbPHffgNi72", "outputId": "48542ee0-d5ab-4cae-d984-1dae31c77bd4" }, "execution_count": 4, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths')]" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "code", "source": [ "data[:10]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VOG5lBemvYei", "outputId": "9b8013a4-8273-4a51-9f1e-59566c9d4892" }, "execution_count": 5, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n", " ('Fungal infection',\n", " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n", " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths')]" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "source": [ "# =============================== Training =====================================" ], "metadata": { "id": "2AoLTRuhNSyp" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW\n", "from torch.utils.data import Dataset, DataLoader\n", "import torch\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model = model.to(device)\n", "\n", "# Sample Data\n", "sample_data = data\n", "\n", "class CustomDataset(Dataset):\n", " def __init__(self, data, tokenizer, max_length):\n", " self.data = data\n", " self.tokenizer = tokenizer\n", " self.max_length = max_length\n", "\n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, index):\n", " disease, symptoms, precautions = self.data[index]\n", " source_text = f\"I am feeling {symptoms}\"\n", " target_text = f\"You might have {disease}, the precautions are {precautions}\"\n", "\n", " # Tokenize the source and target texts separately\n", " source_tokens = self.tokenizer(source_text, padding=\"max_length\", max_length=self.max_length, return_tensors=\"pt\")\n", " target_tokens = self.tokenizer(target_text, padding=\"max_length\", max_length=self.max_length, return_tensors=\"pt\")\n", "\n", " # Prepare the inputs and labels for the Seq2Seq model\n", " input_ids = source_tokens.input_ids.squeeze()\n", " attention_mask = source_tokens.attention_mask.squeeze()\n", " labels = target_tokens.input_ids.squeeze()\n", "\n", " return {\n", " \"input_ids\": input_ids,\n", " \"attention_mask\": attention_mask,\n", " \"labels\": labels,\n", " }\n", "\n", "def fine_tune_and_save_model(model, tokenizer):\n", " # Load tokenizer and create dataset\n", " # checkpoint = \"microsoft/GODEL-v1_1-large-seq2seq\"\n", " # tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", " max_length = 128 # You can adjust this based on your input sequence length requirements\n", " dataset = CustomDataset(sample_data, tokenizer, max_length)\n", "\n", " # Data loader\n", " batch_size = 2\n", " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", "\n", " # Load the model\n", " # from the parameter\n", "\n", " # Hyperparameters\n", " learning_rate = 2e-5\n", " num_epochs = 2\n", " num_warmup_steps = 100\n", "\n", " optimizer = AdamW(model.parameters(), lr=learning_rate)\n", "\n", " # Training loop\n", " model.train()\n", " for epoch in range(num_epochs):\n", " total_loss = 0.0\n", " for batch in dataloader:\n", " optimizer.zero_grad()\n", "\n", " input_ids = batch[\"input_ids\"]\n", " attention_mask = batch[\"attention_mask\"]\n", " labels = batch[\"labels\"]\n", "\n", " input_ids = batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = batch[\"labels\"].to(device)\n", "\n", "\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n", " loss = outputs.loss\n", " total_loss += loss.item()\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " average_loss = total_loss / len(dataloader)\n", " print(f\"Epoch {epoch+1}/{num_epochs} - Average Loss: {average_loss:.4f}\")\n", "\n", " # Save the fine-tuned model and tokenizer\n", " output_dir = \"medbot_model_epoch3_s512\"\n", " model.save_pretrained(output_dir)\n", " tokenizer.save_pretrained(output_dir)" ], "metadata": { "id": "4COYhQqYM0ni" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "fine_tune_and_save_model(model, tokenizer)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "j3p4lBPbOZZP", "outputId": "91397825-200d-4e4a-f4e7-29e9df6c040c" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/2 - Average Loss: 0.1588\n", "Epoch 2/2 - Average Loss: 0.0038\n" ] } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "M9dy5RBfRcCH" }, "execution_count": null, "outputs": [] } ] }