shanover committed on
Commit
bc6e726
·
1 Parent(s): 9980b01

Upload 6 files

Browse files
config.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "microsoft/GODEL-v1_1-large-seq2seq",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 4096,
7
+ "d_kv": 64,
8
+ "d_model": 1024,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "relu",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "relu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": false,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "n_positions": 512,
20
+ "num_decoder_layers": 24,
21
+ "num_heads": 16,
22
+ "num_layers": 24,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "task_specific_params": {
28
+ "summarization": {
29
+ "early_stopping": true,
30
+ "length_penalty": 2.0,
31
+ "max_length": 200,
32
+ "min_length": 30,
33
+ "no_repeat_ngram_size": 3,
34
+ "num_beams": 4,
35
+ "prefix": "summarize: "
36
+ },
37
+ "translation_en_to_de": {
38
+ "early_stopping": true,
39
+ "max_length": 300,
40
+ "num_beams": 4,
41
+ "prefix": "translate English to German: "
42
+ },
43
+ "translation_en_to_fr": {
44
+ "early_stopping": true,
45
+ "max_length": 300,
46
+ "num_beams": 4,
47
+ "prefix": "translate English to French: "
48
+ },
49
+ "translation_en_to_ro": {
50
+ "early_stopping": true,
51
+ "max_length": 300,
52
+ "num_beams": 4,
53
+ "prefix": "translate English to Romanian: "
54
+ }
55
+ },
56
+ "torch_dtype": "float32",
57
+ "transformers_version": "4.31.0",
58
+ "use_cache": true,
59
+ "vocab_size": 32102
60
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.31.0"
7
+ }
peft_attempt_1.ipynb ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "source": [
22
+ "!pip install transformers"
23
+ ],
24
+ "metadata": {
25
+ "colab": {
26
+ "base_uri": "https://localhost:8080/"
27
+ },
28
+ "id": "6_gaeY1UMPOv",
29
+ "outputId": "470ea044-c9b1-400e-f322-aafbdbae4aea"
30
+ },
31
+ "execution_count": 9,
32
+ "outputs": [
33
+ {
34
+ "output_type": "stream",
35
+ "name": "stdout",
36
+ "text": [
37
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.31.0)\n",
38
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
39
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)\n",
40
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n",
41
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
42
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
43
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n",
44
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n",
45
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
46
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)\n",
47
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n",
48
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n",
49
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)\n",
50
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)\n",
51
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
52
+ "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n",
53
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n"
54
+ ]
55
+ }
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "!pip install peft"
62
+ ],
63
+ "metadata": {
64
+ "colab": {
65
+ "base_uri": "https://localhost:8080/"
66
+ },
67
+ "id": "UkDCPUBOMh-L",
68
+ "outputId": "0c618ade-6b5b-4500-8063-a51c29880fb4"
69
+ },
70
+ "execution_count": 13,
71
+ "outputs": [
72
+ {
73
+ "output_type": "stream",
74
+ "name": "stdout",
75
+ "text": [
76
+ "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.4.0)\n",
77
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.22.4)\n",
78
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.1)\n",
79
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n",
80
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n",
81
+ "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.0.1+cu118)\n",
82
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.31.0)\n",
83
+ "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft) (0.21.0)\n",
84
+ "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.3.1)\n",
85
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.12.2)\n",
86
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.7.1)\n",
87
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.11.1)\n",
88
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1)\n",
89
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n",
90
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.0.0)\n",
91
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (3.25.2)\n",
92
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (16.0.6)\n",
93
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.16.4)\n",
94
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2022.10.31)\n",
95
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.27.1)\n",
96
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.13.3)\n",
97
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (4.65.0)\n",
98
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers->peft) (2023.6.0)\n",
99
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n",
100
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (1.26.16)\n",
101
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n",
102
+ "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.12)\n",
103
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n",
104
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n"
105
+ ]
106
+ }
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 2,
112
+ "metadata": {
113
+ "id": "6YOhmSaCMK2M"
114
+ },
115
+ "outputs": [],
116
+ "source": [
117
+ "# from transformers import AutoModelForSeq2SeqLM\n",
118
+ "# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType\n",
119
+ "# import torch\n",
120
+ "# model_name_or_path = \"microsoft/GODEL-v1_1-large-seq2seq\"\n",
121
+ "# tokenizer_name_or_path = \"microsoft/GODEL-v1_1-large-seq2seq\""
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "source": [
127
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
128
+ "\n",
129
+ "# Hugging Face Hub ID of the base checkpoint; change it here to load a different model\n",
130
+ "model_name = 'microsoft/GODEL-v1_1-large-seq2seq'\n",
131
+ "\n",
132
+ "# Load the model and tokenizer\n",
133
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
134
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
135
+ ],
136
+ "metadata": {
137
+ "id": "r1zRNhfYXN8T"
138
+ },
139
+ "execution_count": 2,
140
+ "outputs": []
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "source": [
145
+ "# Output directory\n",
146
+ "output_dir = \"medbot_model\"\n",
147
+ "\n",
148
+ "# Save the model and tokenizer using the standard Hugging Face naming convention\n",
149
+ "model.save_pretrained(output_dir)\n",
150
+ "tokenizer.save_pretrained(output_dir)"
151
+ ],
152
+ "metadata": {
153
+ "colab": {
154
+ "base_uri": "https://localhost:8080/"
155
+ },
156
+ "id": "UjV85bPQXw7P",
157
+ "outputId": "688d07cb-eddd-4a6a-819e-57efd837324b"
158
+ },
159
+ "execution_count": 15,
160
+ "outputs": [
161
+ {
162
+ "output_type": "execute_result",
163
+ "data": {
164
+ "text/plain": [
165
+ "('medbot_model/tokenizer_config.json',\n",
166
+ " 'medbot_model/special_tokens_map.json',\n",
167
+ " 'medbot_model/tokenizer.json')"
168
+ ]
169
+ },
170
+ "metadata": {},
171
+ "execution_count": 15
172
+ }
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "source": [
178
+ "# # peft config\n",
179
+ "\n",
180
+ "# peft_config = LoraConfig(\n",
181
+ "# task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=6, lora_alpha=16, lora_dropout=0.2\n",
182
+ "# )"
183
+ ],
184
+ "metadata": {
185
+ "id": "qmIGSnctujOh"
186
+ },
187
+ "execution_count": 12,
188
+ "outputs": []
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "source": [
193
+ "# model = get_peft_model(model, peft_config)\n",
194
+ "# model.print_trainable_parameters()\n",
195
+ "\n",
196
+ "# output_dir = \"medbot_model_peft\"\n",
197
+ "\n",
198
+ "# model.save_pretrained(output_dir)\n",
199
+ "# tokenizer.save_pretrained(output_dir)"
200
+ ],
201
+ "metadata": {
202
+ "colab": {
203
+ "base_uri": "https://localhost:8080/"
204
+ },
205
+ "id": "RulB42QiMOhi",
206
+ "outputId": "e8e2d65d-8afa-4095-bf8b-93749e39b785"
207
+ },
208
+ "execution_count": 14,
209
+ "outputs": [
210
+ {
211
+ "output_type": "stream",
212
+ "name": "stdout",
213
+ "text": [
214
+ "trainable params: 1,769,472 || all params: 739,410,944 || trainable%: 0.23930833244469804\n"
215
+ ]
216
+ },
217
+ {
218
+ "output_type": "execute_result",
219
+ "data": {
220
+ "text/plain": [
221
+ "('medbot_model_peft/tokenizer_config.json',\n",
222
+ " 'medbot_model_peft/special_tokens_map.json',\n",
223
+ " 'medbot_model_peft/tokenizer.json')"
224
+ ]
225
+ },
226
+ "metadata": {},
227
+ "execution_count": 14
228
+ }
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "source": [
234
+ "# ============================== Load Dataset =========================="
235
+ ],
236
+ "metadata": {
237
+ "id": "Xj4K4WU-NYp8"
238
+ },
239
+ "execution_count": 8,
240
+ "outputs": []
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "source": [
245
+ "import pandas as pd\n",
246
+ "\n",
247
+ "df = pd.read_csv('/content/drive/MyDrive/Dataset/diseaseDataSetFull2.csv')\n",
248
+ "df"
249
+ ],
250
+ "metadata": {
251
+ "colab": {
252
+ "base_uri": "https://localhost:8080/",
253
+ "height": 607
254
+ },
255
+ "id": "Wc8bxpybNgFo",
256
+ "outputId": "e5e960a5-e460-4b27-ea2f-bfb86dbbb06b"
257
+ },
258
+ "execution_count": 3,
259
+ "outputs": [
260
+ {
261
+ "output_type": "execute_result",
262
+ "data": {
263
+ "text/plain": [
264
+ " disease \\\n",
265
+ "0 Fungal infection \n",
266
+ "1 Fungal infection \n",
267
+ "2 Fungal infection \n",
268
+ "3 Fungal infection \n",
269
+ "4 Fungal infection \n",
270
+ "... ... \n",
271
+ "4915 (vertigo) Paroymsal Positional Vertigo \n",
272
+ "4916 Acne \n",
273
+ "4917 Urinary tract infection \n",
274
+ "4918 Psoriasis \n",
275
+ "4919 Impetigo \n",
276
+ "\n",
277
+ " symptoms \\\n",
278
+ "0 itching,skin_rash,nodal_skin_eruptions,dischro... \n",
279
+ "1 skin_rash,nodal_skin_eruptions,dischromic__pat... \n",
280
+ "2 itching,nodal_skin_eruptions,dischromic__patches \n",
281
+ "3 itching,skin_rash,dischromic__patches \n",
282
+ "4 itching,skin_rash,nodal_skin_eruptions \n",
283
+ "... ... \n",
284
+ "4915 vomiting,headache,nausea,spinning_movements,lo... \n",
285
+ "4916 skin_rash,pus_filled_pimples,blackheads,scurring \n",
286
+ "4917 burning_micturition,bladder_discomfort,foul_sm... \n",
287
+ "4918 skin_rash,joint_pain,skin_peeling,silver_like_... \n",
288
+ "4919 skin_rash,high_fever,blister,red_sore_around_n... \n",
289
+ "\n",
290
+ " precautions \n",
291
+ "0 bath twice, use detol or neem in bathing water... \n",
292
+ "1 bath twice, use detol or neem in bathing water... \n",
293
+ "2 bath twice, use detol or neem in bathing water... \n",
294
+ "3 bath twice, use detol or neem in bathing water... \n",
295
+ "4 bath twice, use detol or neem in bathing water... \n",
296
+ "... ... \n",
297
+ "4915 lie down, avoid sudden change in body, avoid a... \n",
298
+ "4916 bath twice, avoid fatty spicy food, drink plen... \n",
299
+ "4917 drink plenty of water, increase vitamin c inta... \n",
300
+ "4918 wash hands with warm soapy water, stop bleedin... \n",
301
+ "4919 soak affected area in warm water, use antibiot... \n",
302
+ "\n",
303
+ "[4920 rows x 3 columns]"
304
+ ],
305
+ "text/html": [
306
+ "\n",
307
+ "\n",
308
+ " <div id=\"df-d8eb48fc-bfd8-4158-80d6-468b1560edca\">\n",
309
+ " <div class=\"colab-df-container\">\n",
310
+ " <div>\n",
311
+ "<style scoped>\n",
312
+ " .dataframe tbody tr th:only-of-type {\n",
313
+ " vertical-align: middle;\n",
314
+ " }\n",
315
+ "\n",
316
+ " .dataframe tbody tr th {\n",
317
+ " vertical-align: top;\n",
318
+ " }\n",
319
+ "\n",
320
+ " .dataframe thead th {\n",
321
+ " text-align: right;\n",
322
+ " }\n",
323
+ "</style>\n",
324
+ "<table border=\"1\" class=\"dataframe\">\n",
325
+ " <thead>\n",
326
+ " <tr style=\"text-align: right;\">\n",
327
+ " <th></th>\n",
328
+ " <th>disease</th>\n",
329
+ " <th>symptoms</th>\n",
330
+ " <th>precautions</th>\n",
331
+ " </tr>\n",
332
+ " </thead>\n",
333
+ " <tbody>\n",
334
+ " <tr>\n",
335
+ " <th>0</th>\n",
336
+ " <td>Fungal infection</td>\n",
337
+ " <td>itching,skin_rash,nodal_skin_eruptions,dischro...</td>\n",
338
+ " <td>bath twice, use detol or neem in bathing water...</td>\n",
339
+ " </tr>\n",
340
+ " <tr>\n",
341
+ " <th>1</th>\n",
342
+ " <td>Fungal infection</td>\n",
343
+ " <td>skin_rash,nodal_skin_eruptions,dischromic__pat...</td>\n",
344
+ " <td>bath twice, use detol or neem in bathing water...</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <th>2</th>\n",
348
+ " <td>Fungal infection</td>\n",
349
+ " <td>itching,nodal_skin_eruptions,dischromic__patches</td>\n",
350
+ " <td>bath twice, use detol or neem in bathing water...</td>\n",
351
+ " </tr>\n",
352
+ " <tr>\n",
353
+ " <th>3</th>\n",
354
+ " <td>Fungal infection</td>\n",
355
+ " <td>itching,skin_rash,dischromic__patches</td>\n",
356
+ " <td>bath twice, use detol or neem in bathing water...</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <th>4</th>\n",
360
+ " <td>Fungal infection</td>\n",
361
+ " <td>itching,skin_rash,nodal_skin_eruptions</td>\n",
362
+ " <td>bath twice, use detol or neem in bathing water...</td>\n",
363
+ " </tr>\n",
364
+ " <tr>\n",
365
+ " <th>...</th>\n",
366
+ " <td>...</td>\n",
367
+ " <td>...</td>\n",
368
+ " <td>...</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <th>4915</th>\n",
372
+ " <td>(vertigo) Paroymsal Positional Vertigo</td>\n",
373
+ " <td>vomiting,headache,nausea,spinning_movements,lo...</td>\n",
374
+ " <td>lie down, avoid sudden change in body, avoid a...</td>\n",
375
+ " </tr>\n",
376
+ " <tr>\n",
377
+ " <th>4916</th>\n",
378
+ " <td>Acne</td>\n",
379
+ " <td>skin_rash,pus_filled_pimples,blackheads,scurring</td>\n",
380
+ " <td>bath twice, avoid fatty spicy food, drink plen...</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <th>4917</th>\n",
384
+ " <td>Urinary tract infection</td>\n",
385
+ " <td>burning_micturition,bladder_discomfort,foul_sm...</td>\n",
386
+ " <td>drink plenty of water, increase vitamin c inta...</td>\n",
387
+ " </tr>\n",
388
+ " <tr>\n",
389
+ " <th>4918</th>\n",
390
+ " <td>Psoriasis</td>\n",
391
+ " <td>skin_rash,joint_pain,skin_peeling,silver_like_...</td>\n",
392
+ " <td>wash hands with warm soapy water, stop bleedin...</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <th>4919</th>\n",
396
+ " <td>Impetigo</td>\n",
397
+ " <td>skin_rash,high_fever,blister,red_sore_around_n...</td>\n",
398
+ " <td>soak affected area in warm water, use antibiot...</td>\n",
399
+ " </tr>\n",
400
+ " </tbody>\n",
401
+ "</table>\n",
402
+ "<p>4920 rows × 3 columns</p>\n",
403
+ "</div>\n",
404
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d8eb48fc-bfd8-4158-80d6-468b1560edca')\"\n",
405
+ " title=\"Convert this dataframe to an interactive table.\"\n",
406
+ " style=\"display:none;\">\n",
407
+ "\n",
408
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
409
+ " width=\"24px\">\n",
410
+ " <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
411
+ " <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
412
+ " </svg>\n",
413
+ " </button>\n",
414
+ "\n",
415
+ "\n",
416
+ "\n",
417
+ " <div id=\"df-effcd91d-4d34-4d62-9ee1-b5487b3c0b00\">\n",
418
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-effcd91d-4d34-4d62-9ee1-b5487b3c0b00')\"\n",
419
+ " title=\"Suggest charts.\"\n",
420
+ " style=\"display:none;\">\n",
421
+ "\n",
422
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
423
+ " width=\"24px\">\n",
424
+ " <g>\n",
425
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
426
+ " </g>\n",
427
+ "</svg>\n",
428
+ " </button>\n",
429
+ " </div>\n",
430
+ "\n",
431
+ "<style>\n",
432
+ " .colab-df-quickchart {\n",
433
+ " background-color: #E8F0FE;\n",
434
+ " border: none;\n",
435
+ " border-radius: 50%;\n",
436
+ " cursor: pointer;\n",
437
+ " display: none;\n",
438
+ " fill: #1967D2;\n",
439
+ " height: 32px;\n",
440
+ " padding: 0 0 0 0;\n",
441
+ " width: 32px;\n",
442
+ " }\n",
443
+ "\n",
444
+ " .colab-df-quickchart:hover {\n",
445
+ " background-color: #E2EBFA;\n",
446
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
447
+ " fill: #174EA6;\n",
448
+ " }\n",
449
+ "\n",
450
+ " [theme=dark] .colab-df-quickchart {\n",
451
+ " background-color: #3B4455;\n",
452
+ " fill: #D2E3FC;\n",
453
+ " }\n",
454
+ "\n",
455
+ " [theme=dark] .colab-df-quickchart:hover {\n",
456
+ " background-color: #434B5C;\n",
457
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
458
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
459
+ " fill: #FFFFFF;\n",
460
+ " }\n",
461
+ "</style>\n",
462
+ "\n",
463
+ " <script>\n",
464
+ " async function quickchart(key) {\n",
465
+ " const containerElement = document.querySelector('#' + key);\n",
466
+ " const charts = await google.colab.kernel.invokeFunction(\n",
467
+ " 'suggestCharts', [key], {});\n",
468
+ " }\n",
469
+ " </script>\n",
470
+ "\n",
471
+ " <script>\n",
472
+ "\n",
473
+ "function displayQuickchartButton(domScope) {\n",
474
+ " let quickchartButtonEl =\n",
475
+ " domScope.querySelector('#df-effcd91d-4d34-4d62-9ee1-b5487b3c0b00 button.colab-df-quickchart');\n",
476
+ " quickchartButtonEl.style.display =\n",
477
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
478
+ "}\n",
479
+ "\n",
480
+ " displayQuickchartButton(document);\n",
481
+ " </script>\n",
482
+ " <style>\n",
483
+ " .colab-df-container {\n",
484
+ " display:flex;\n",
485
+ " flex-wrap:wrap;\n",
486
+ " gap: 12px;\n",
487
+ " }\n",
488
+ "\n",
489
+ " .colab-df-convert {\n",
490
+ " background-color: #E8F0FE;\n",
491
+ " border: none;\n",
492
+ " border-radius: 50%;\n",
493
+ " cursor: pointer;\n",
494
+ " display: none;\n",
495
+ " fill: #1967D2;\n",
496
+ " height: 32px;\n",
497
+ " padding: 0 0 0 0;\n",
498
+ " width: 32px;\n",
499
+ " }\n",
500
+ "\n",
501
+ " .colab-df-convert:hover {\n",
502
+ " background-color: #E2EBFA;\n",
503
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
504
+ " fill: #174EA6;\n",
505
+ " }\n",
506
+ "\n",
507
+ " [theme=dark] .colab-df-convert {\n",
508
+ " background-color: #3B4455;\n",
509
+ " fill: #D2E3FC;\n",
510
+ " }\n",
511
+ "\n",
512
+ " [theme=dark] .colab-df-convert:hover {\n",
513
+ " background-color: #434B5C;\n",
514
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
515
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
516
+ " fill: #FFFFFF;\n",
517
+ " }\n",
518
+ " </style>\n",
519
+ "\n",
520
+ " <script>\n",
521
+ " const buttonEl =\n",
522
+ " document.querySelector('#df-d8eb48fc-bfd8-4158-80d6-468b1560edca button.colab-df-convert');\n",
523
+ " buttonEl.style.display =\n",
524
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
525
+ "\n",
526
+ " async function convertToInteractive(key) {\n",
527
+ " const element = document.querySelector('#df-d8eb48fc-bfd8-4158-80d6-468b1560edca');\n",
528
+ " const dataTable =\n",
529
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
530
+ " [key], {});\n",
531
+ " if (!dataTable) return;\n",
532
+ "\n",
533
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
534
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
535
+ " + ' to learn more about interactive tables.';\n",
536
+ " element.innerHTML = '';\n",
537
+ " dataTable['output_type'] = 'display_data';\n",
538
+ " await google.colab.output.renderOutput(dataTable, element);\n",
539
+ " const docLink = document.createElement('div');\n",
540
+ " docLink.innerHTML = docLinkHtml;\n",
541
+ " element.appendChild(docLink);\n",
542
+ " }\n",
543
+ " </script>\n",
544
+ " </div>\n",
545
+ " </div>\n"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "execution_count": 3
550
+ }
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "source": [
556
+ "def dataframe_to_dataset(df):\n",
557
+ " \"\"\"\n",
558
+ " Convert a DataFrame with columns 'disease', 'symptoms', and 'precautions'\n",
559
+ " into a dataset represented as a list of tuples.\n",
560
+ "\n",
561
+ " Parameters:\n",
562
+ " df (pd.DataFrame): Input DataFrame with columns 'disease', 'symptoms', and 'precautions'.\n",
563
+ "\n",
564
+ " Returns:\n",
565
+ " list: A list of tuples, where each tuple contains information about a specific disease,\n",
566
+ " symptoms, and precautions.\n",
567
+ " \"\"\"\n",
568
+ " if not all(col in df.columns for col in ['disease', 'symptoms', 'precautions']):\n",
569
+ " raise ValueError(\"DataFrame must contain 'disease', 'symptoms', and 'precautions' columns.\")\n",
570
+ "\n",
571
+ " dataset = []\n",
572
+ " for _, row in df.iterrows():\n",
573
+ " disease = row['disease']\n",
574
+ " symptoms = row['symptoms']\n",
575
+ " precautions = row['precautions']\n",
576
+ " dataset.append((disease, symptoms, precautions))\n",
577
+ "\n",
578
+ " return dataset\n",
579
+ "\n",
580
+ "data = dataframe_to_dataset(df)\n",
581
+ "data[:10]"
582
+ ],
583
+ "metadata": {
584
+ "colab": {
585
+ "base_uri": "https://localhost:8080/"
586
+ },
587
+ "id": "rGbPHffgNi72",
588
+ "outputId": "48542ee0-d5ab-4cae-d984-1dae31c77bd4"
589
+ },
590
+ "execution_count": 4,
591
+ "outputs": [
592
+ {
593
+ "output_type": "execute_result",
594
+ "data": {
595
+ "text/plain": [
596
+ "[('Fungal infection',\n",
597
+ " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
598
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
599
+ " ('Fungal infection',\n",
600
+ " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
601
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
602
+ " ('Fungal infection',\n",
603
+ " 'itching,nodal_skin_eruptions,dischromic__patches',\n",
604
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
605
+ " ('Fungal infection',\n",
606
+ " 'itching,skin_rash,dischromic__patches',\n",
607
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
608
+ " ('Fungal infection',\n",
609
+ " 'itching,skin_rash,nodal_skin_eruptions',\n",
610
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
611
+ " ('Fungal infection',\n",
612
+ " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
613
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
614
+ " ('Fungal infection',\n",
615
+ " 'itching,nodal_skin_eruptions,dischromic__patches',\n",
616
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
617
+ " ('Fungal infection',\n",
618
+ " 'itching,skin_rash,dischromic__patches',\n",
619
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
620
+ " ('Fungal infection',\n",
621
+ " 'itching,skin_rash,nodal_skin_eruptions',\n",
622
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
623
+ " ('Fungal infection',\n",
624
+ " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
625
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths')]"
626
+ ]
627
+ },
628
+ "metadata": {},
629
+ "execution_count": 4
630
+ }
631
+ ]
632
+ },
633
+ {
634
+ "cell_type": "code",
635
+ "source": [
636
+ "data[:10]"
637
+ ],
638
+ "metadata": {
639
+ "colab": {
640
+ "base_uri": "https://localhost:8080/"
641
+ },
642
+ "id": "VOG5lBemvYei",
643
+ "outputId": "9b8013a4-8273-4a51-9f1e-59566c9d4892"
644
+ },
645
+ "execution_count": 5,
646
+ "outputs": [
647
+ {
648
+ "output_type": "execute_result",
649
+ "data": {
650
+ "text/plain": [
651
+ "[('Fungal infection',\n",
652
+ " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
653
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
654
+ " ('Fungal infection',\n",
655
+ " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
656
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
657
+ " ('Fungal infection',\n",
658
+ " 'itching,nodal_skin_eruptions,dischromic__patches',\n",
659
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
660
+ " ('Fungal infection',\n",
661
+ " 'itching,skin_rash,dischromic__patches',\n",
662
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
663
+ " ('Fungal infection',\n",
664
+ " 'itching,skin_rash,nodal_skin_eruptions',\n",
665
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
666
+ " ('Fungal infection',\n",
667
+ " 'skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
668
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
669
+ " ('Fungal infection',\n",
670
+ " 'itching,nodal_skin_eruptions,dischromic__patches',\n",
671
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
672
+ " ('Fungal infection',\n",
673
+ " 'itching,skin_rash,dischromic__patches',\n",
674
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
675
+ " ('Fungal infection',\n",
676
+ " 'itching,skin_rash,nodal_skin_eruptions',\n",
677
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths'),\n",
678
+ " ('Fungal infection',\n",
679
+ " 'itching,skin_rash,nodal_skin_eruptions,dischromic__patches',\n",
680
+ " 'bath twice, use detol or neem in bathing water, keep infected area dry, use clean cloths')]"
681
+ ]
682
+ },
683
+ "metadata": {},
684
+ "execution_count": 5
685
+ }
686
+ ]
687
+ },
688
+ {
689
+ "cell_type": "code",
690
+ "source": [
691
+ "# =============================== Training ====================================="
692
+ ],
693
+ "metadata": {
694
+ "id": "2AoLTRuhNSyp"
695
+ },
696
+ "execution_count": 7,
697
+ "outputs": []
698
+ },
699
+ {
700
+ "cell_type": "code",
701
+ "source": [
702
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW\n",
703
+ "from torch.utils.data import Dataset, DataLoader\n",
704
+ "import torch\n",
705
+ "\n",
706
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
707
+ "model = model.to(device)\n",
708
+ "\n",
709
+ "# Sample Data\n",
710
+ "sample_data = data\n",
711
+ "\n",
712
+ "class CustomDataset(Dataset):\n",
713
+ "    \"\"\"Seq2seq training pairs built from (disease, symptoms, precautions) tuples.\"\"\"\n",
714
+ "    def __init__(self, data, tokenizer, max_length):\n",
715
+ "        self.data = data\n",
716
+ "        self.tokenizer = tokenizer\n",
717
+ "        self.max_length = max_length  # pad/truncate length used for both source and target\n",
718
+ "\n",
719
+ "    def __len__(self):\n",
720
+ "        return len(self.data)\n",
721
+ "\n",
722
+ "    def __getitem__(self, index):\n",
723
+ "        \"\"\"Return input_ids, attention_mask and loss-masked labels for sample ``index``.\"\"\"\n",
724
+ "        disease, symptoms, precautions = self.data[index]\n",
725
+ "        source_text = f\"I am feeling {symptoms}\"\n",
726
+ "        target_text = f\"You might have {disease}, the precautions are {precautions}\"\n",
727
+ "        # truncation=True caps every encoding at max_length so the default collate can stack samples\n",
728
+ "        source_tokens = self.tokenizer(source_text, padding=\"max_length\", truncation=True, max_length=self.max_length, return_tensors=\"pt\")\n",
729
+ "        target_tokens = self.tokenizer(target_text, padding=\"max_length\", truncation=True, max_length=self.max_length, return_tensors=\"pt\")\n",
730
+ "        # squeeze(0) removes only the batch dimension added by return_tensors=\"pt\"\n",
731
+ "        input_ids = source_tokens.input_ids.squeeze(0)\n",
732
+ "        attention_mask = source_tokens.attention_mask.squeeze(0)\n",
733
+ "        # Padded target positions become -100 so the loss ignores them\n",
734
+ "        labels = target_tokens.input_ids.squeeze(0).masked_fill(target_tokens.attention_mask.squeeze(0) == 0, -100)\n",
735
+ "        return {\n",
736
+ "            \"input_ids\": input_ids,\n",
737
+ "            \"attention_mask\": attention_mask,\n",
738
+ "            \"labels\": labels,\n",
739
+ "        }\n",
740
+ "\n",
741
+ "def fine_tune_and_save_model(model, tokenizer):\n",
742
+ "    \"\"\"Fine-tune ``model`` on the module-level ``sample_data`` and save it together with ``tokenizer``.\n",
743
+ "\n",
744
+ "    Args:\n",
745
+ "        model: seq2seq model whose forward pass returns ``.loss`` when given ``labels``; expected to already be on ``device``.\n",
746
+ "        tokenizer: tokenizer used to encode the samples; saved next to the model.\n",
747
+ "\n",
748
+ "    Prints the average loss after each epoch and writes the model and tokenizer to ``medbot_model_epoch3_s512``.\n",
749
+ "    \"\"\"\n",
750
+ "    max_length = 128  # You can adjust this based on your input sequence length requirements\n",
751
+ "    dataset = CustomDataset(sample_data, tokenizer, max_length)\n",
752
+ "\n",
753
+ "    # Data loader\n",
754
+ "    batch_size = 2\n",
755
+ "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
756
+ "\n",
757
+ "    # Hyperparameters\n",
758
+ "    learning_rate = 2e-5\n",
759
+ "    num_epochs = 2\n",
760
+ "\n",
761
+ "    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps/weight_decay keep its former defaults\n",
762
+ "    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)\n",
763
+ "\n",
764
+ "    # Training loop\n",
765
+ "    model.train()\n",
766
+ "    for epoch in range(num_epochs):\n",
767
+ "        total_loss = 0.0\n",
768
+ "        for batch in dataloader:\n",
769
+ "            optimizer.zero_grad()\n",
770
+ "\n",
771
+ "            # Move the batch tensors to the device used for training\n",
772
+ "            input_ids = batch[\"input_ids\"].to(device)\n",
773
+ "            attention_mask = batch[\"attention_mask\"].to(device)\n",
774
+ "            labels = batch[\"labels\"].to(device)\n",
775
+ "\n",
776
+ "            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
777
+ "            loss = outputs.loss\n",
778
+ "            total_loss += loss.item()\n",
779
+ "\n",
780
+ "            loss.backward()\n",
781
+ "            optimizer.step()\n",
782
+ "\n",
783
+ "        # max(..., 1) avoids a ZeroDivisionError when the dataset is empty\n",
784
+ "        average_loss = total_loss / max(len(dataloader), 1)\n",
785
+ "        print(f\"Epoch {epoch+1}/{num_epochs} - Average Loss: {average_loss:.4f}\")\n",
786
+ "\n",
787
+ "    # Save the fine-tuned model and tokenizer\n",
788
+ "    # NOTE(review): the name says epoch3/s512 but this run uses num_epochs=2 and max_length=128; kept so existing paths still resolve\n",
789
+ "    output_dir = \"medbot_model_epoch3_s512\"\n",
790
+ "    model.save_pretrained(output_dir)\n",
791
+ "    tokenizer.save_pretrained(output_dir)"
792
+ ],
793
+ "metadata": {
794
+ "id": "4COYhQqYM0ni"
795
+ },
796
+ "execution_count": 6,
797
+ "outputs": []
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "source": [
802
+ "fine_tune_and_save_model(model, tokenizer)"
803
+ ],
804
+ "metadata": {
805
+ "colab": {
806
+ "base_uri": "https://localhost:8080/"
807
+ },
808
+ "id": "j3p4lBPbOZZP",
809
+ "outputId": "91397825-200d-4e4a-f4e7-29e9df6c040c"
810
+ },
811
+ "execution_count": 7,
812
+ "outputs": [
813
+ {
814
+ "output_type": "stream",
815
+ "name": "stderr",
816
+ "text": [
817
+ "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
818
+ " warnings.warn(\n"
819
+ ]
820
+ },
821
+ {
822
+ "output_type": "stream",
823
+ "name": "stdout",
824
+ "text": [
825
+ "Epoch 1/2 - Average Loss: 0.1588\n",
826
+ "Epoch 2/2 - Average Loss: 0.0038\n"
827
+ ]
828
+ }
829
+ ]
830
+ },
831
+ {
832
+ "cell_type": "code",
833
+ "source": [],
834
+ "metadata": {
835
+ "id": "M9dy5RBfRcCH"
836
+ },
837
+ "execution_count": null,
838
+ "outputs": []
839
+ }
840
+ ]
841
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": "</s>",
105
+ "pad_token": "<PAD>",
106
+ "unk_token": "<unk>"
107
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "clean_up_tokenization_spaces": true,
105
+ "eos_token": "</s>",
106
+ "extra_ids": 100,
107
+ "model_max_length": 512,
108
+ "pad_token": "<pad>",
109
+ "tokenizer_class": "T5Tokenizer",
110
+ "unk_token": "<unk>"
111
+ }