Pankaj Singh Rawat committed on
Commit 9e582c5 · 0 Parent(s)

Initial commit

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ wandb
+ __pycache__
+ transliteration
README copy.md ADDED
@@ -0,0 +1,130 @@
+ # Sequence to Sequence Language Transliteration using RNNs and Transformers
+
+ This repository contains the files for the third assignment of the course CS6910 - Deep Learning at IIT Madras.
+
+ The transformers part was added later and was not part of the assignment.
+
+ Implemented an Encoder-Decoder architecture with and without an attention mechanism, and later with Transformers, and used them to perform transliteration on the provided Aksharantar dataset (English-Hindi transliteration pairs). The recurrent models were built using the RNN, LSTM and GRU cells provided by PyTorch.
+
+ The transformer architecture is built from scratch following the "Attention Is All You Need" paper, using only the basic feed-forward and embedding layers from PyTorch.
+
+ Jump to Section: [Usage](#usage)
+
+ Report: [Report](https://wandb.ai/iitmadras/CS6910_Assignment_3/reports/CS6910-Assignment-3-Report--Vmlldzo0MzQyNDk5)
+
+ ## Encoder
+
+ The encoder is a simple cell of either LSTM, RNN or GRU. The input to the encoder is a sequence of characters and the output is a sequence of hidden states. The hidden state of the last time step is used as the context vector for the decoder.
+
+ The encoder can also be a transformer encoder with multiple layers of self-attention. The output generated by the encoder is fed to the transformer decoder.
+
+ ## Decoder
+
+ The decoder is again a simple cell of either LSTM, RNN or GRU. The input to the decoder is the hidden state of the encoder and the output of the previous time step. The output of the decoder is a sequence of characters. The decoder has an additional fully connected layer and a log softmax which are used to predict the next character.
+
+ The decoder can also be a transformer decoder with multiple layers of masked self-attention and cross-attention. The output generated by the encoder is fed as input to the transformer decoder, and next-character prediction is used to generate the complete target sequence in Hindi.
+
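+ As a concrete illustration of the recurrent decode loop described above, here is a minimal, self-contained sketch (hypothetical sizes, GRU cells and greedy decoding; not the exact training code):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ vocab, embed, hidden = 30, 16, 64       # hypothetical sizes
+ enc_emb = nn.Embedding(vocab, embed)
+ encoder = nn.GRU(embed, hidden)
+ dec_emb = nn.Embedding(vocab, embed)
+ decoder = nn.GRU(embed, hidden)
+ out = nn.Linear(hidden, vocab)
+
+ src = torch.randint(0, vocab, (12, 1))  # (seq_len, batch) of character indices
+ _, h = encoder(enc_emb(src))            # last hidden state = context vector
+
+ tok = torch.full((1, 1), 2)             # start-of-sequence index ('^')
+ for _ in range(20):                     # emit one character at a time
+     step, h = decoder(dec_emb(tok), h)
+     tok = out(step).argmax(dim=-1)      # greedily pick the most likely character
+ ```
+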
+ ## Attention Mechanism
+
+ The attention mechanism is dot-product attention. Attention weights are calculated by taking the softmax of the dot products between the decoder hidden state and the encoder hidden states; the weighted sum of the encoder hidden states is then concatenated with the decoder hidden state and passed through a fully connected layer to produce the decoder output.
+
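+ A minimal sketch of that computation (hypothetical shapes; it mirrors the idea rather than the exact layer names used in the code):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ B, T, H = 16, 27, 256                    # hypothetical batch, source length, hidden size
+ enc_out = torch.randn(B, T, H)           # encoder hidden states
+ dec_h = torch.randn(B, H)                # current decoder hidden state
+
+ scores = torch.bmm(enc_out, dec_h.unsqueeze(2)).squeeze(2)     # (B, T) dot products
+ weights = F.softmax(scores, dim=1)                             # attention weights
+ context = torch.bmm(weights.unsqueeze(1), enc_out).squeeze(1)  # (B, H) weighted sum
+ combined = torch.cat((context, dec_h), dim=1)                  # fed to a fully connected layer
+ ```
+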
+ ## Dataset
+
+ The dataset used is the Aksharantar dataset provided by the course. For each language in a subset of Indian languages it contains 3 files, namely `train.csv`, `valid.csv` and `test.csv`. I have used the Hindi dataset for this assignment. Each file contains 2 columns, namely the English and Hindi words, which are the input and output strings respectively.
+
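+ For example, a split can be loaded like this (a sketch; it assumes the `aksharantar_sampled` layout used in the notebooks, where the CSVs have no header row):
+
+ ```python
+ import pandas as pd
+
+ # column 0 = English (Latin script), column 1 = Hindi (Devanagari)
+ train_df = pd.read_csv("aksharantar_sampled/hin/hin_train.csv", header=None)
+ print(train_df[0][0], train_df[1][0])  # one English-Hindi transliteration pair
+ ```
+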
+ ## Used Python Libraries and Versions
+
+ - Python 3.10.9
+ - PyTorch 1.13.1
+ - Pandas 1.5.3
+
+ ## Usage
+
+ To run the training code for the standard encoder-decoder architecture using the best set of hyperparameters, run the following command:
+
+ ```bash
+ python3 train.py
+ ```
+
+ To run the training code for the encoder-decoder architecture with the attention mechanism using the best set of hyperparameters, run the following command:
+
+ ```bash
+ python3 train_attention.py
+ ```
+
+ To run the inference code for the standard encoder-decoder architecture using the best set of hyperparameters, run the following command. It uses the state dicts stored in the `best_models` folder and writes the test predictions to a file named `test_gen.txt`:
+
+ ```bash
+ python3 test_best_vanilla.py
+ ```
+
+ To run the inference code for the encoder-decoder architecture with the attention mechanism using the best set of hyperparameters, run the following command. It likewise uses the state dicts stored in the `best_models` folder and writes the test predictions to `test_gen.txt`:
+
+ ```bash
+ python3 test_best_attention.py
+ ```
+
+ To train with custom hyperparameters, first list the available options:
+
+ ```bash
+ python3 train.py -h
+ ```
+
+ ```bash
+ # The output of the above command is as follows:
+ usage: train.py [-h]
+ [-es EMBED_SIZE]
+ [-hs HIDDEN_SIZE]
+ [-ct CELL_TYPE]
+ [-nl NUM_LAYERS]
+ [-d DROPOUT]
+ [-lr LEARNING_RATE]
+ [-o OPTIMIZER]
+ [-l LANGUAGE]
+
+ Transliteration Model
+
+ options:
+ -h, --help show this help message and exit
+ -es EMBED_SIZE, --embed_size EMBED_SIZE Embedding Size, good_choices = [8, 16, 32]
+ -hs HIDDEN_SIZE, --hidden_size HIDDEN_SIZE Hidden Size, good_choices = [128, 256, 512]
+ -ct CELL_TYPE, --cell_type CELL_TYPE Cell Type, choices: [LSTM, GRU, RNN]
+ -nl NUM_LAYERS, --num_layers NUM_LAYERS Number of Layers, choices: [1, 2, 3]
+ -d DROPOUT, --dropout DROPOUT Dropout, good_choices: [0, 0.1, 0.2]
+ -lr LEARNING_RATE, --learning_rate LEARNING_RATE Learning Rate, good_choices: [0.0005, 0.001, 0.005]
+ -o OPTIMIZER, --optimizer OPTIMIZER Optimizer, choices: [SGD, ADAM]
+ -l LANGUAGE, --language LANGUAGE Language
+ ```
+
+ To train the attention model with custom hyperparameters, list its options in the same way:
+
+ ```bash
+ python3 train_attention.py -h
+ ```
+
+ ```bash
+ usage: train_attention.py [-h]
+ [-es EMBED_SIZE]
+ [-hs HIDDEN_SIZE]
+ [-ct CELL_TYPE]
+ [-nl NUM_LAYERS]
+ [-dr DROPOUT]
+ [-lr LEARNING_RATE]
+ [-op OPTIMIZER]
+ [-wd WEIGHT_DECAY]
+ [-l LANG]
+
+ Transliteration Model with Attention
+
+ options:
+ -h, --help show this help message and exit
+ -es EMBED_SIZE, --embed_size EMBED_SIZE Embedding size
+ -hs HIDDEN_SIZE, --hidden_size HIDDEN_SIZE Hidden size
+ -ct CELL_TYPE, --cell_type CELL_TYPE Cell type
+ -nl NUM_LAYERS, --num_layers NUM_LAYERS Number of layers
+ -dr DROPOUT, --dropout DROPOUT Dropout
+ -lr LEARNING_RATE, --learning_rate LEARNING_RATE Learning rate
+ -op OPTIMIZER, --optimizer OPTIMIZER Optimizer
+ -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY Weight decay
+ -l LANG, --lang LANG Language
+ ```
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Transliteration
+ emoji: 📈
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 4.44.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,23 @@
+ import gradio as gr
+ from inference.language import Language
+ from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
+ from inference.transformer import generate
+
+ # Run the transliteration model on the user's input
+ def predict(user_input):
+     # Split the input sentence into words
+     words = user_input.split(" ")
+
+     result = generate(words)
+
+     # Join the transliterated words back into a sentence
+     return " ".join(result)
+
+
+ # Launch the Gradio interface
+ if __name__ == "__main__":
+     gr.Interface(predict,
+                  inputs=gr.Textbox(placeholder="Your Hinglish text"),
+                  outputs=gr.Textbox(placeholder="Output Hindi text"),
+                  description="An English to Hindi transliteration app",
+                  examples=["namaste aapko", "kese ho aap, sab badiya"]).launch(share=False)
fast_api.py ADDED
@@ -0,0 +1,38 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from inference.language import Language
+ from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
+ from inference.transformer import generate
+ from typing import List
+ import uvicorn
+
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Create a request model to define the input for the transliteration pipeline
+ class TransRequest(BaseModel):
+     query: str
+
+ # Create a response model to define the output of the transliteration pipeline
+ class TransResponse(BaseModel):
+     response: List[str]
+
+ # Define a FastAPI endpoint for the transliteration pipeline
+ @app.post("/trans", response_model=TransResponse)
+ async def get_transliteration(request: TransRequest):
+     try:
+         # Call the transliteration function with the words of the query
+         words = request.query.split(" ")
+         result = generate(words)
+
+         return TransResponse(
+             response=result
+         )
+     except Exception as e:
+         # In case of an error, return an HTTPException with a 500 status code
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Run the FastAPI application (for local testing)
+ if __name__ == "__main__":
+     uvicorn.run(app, host="127.0.0.1", port=8000)
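A quick way to exercise this endpoint once the server is running (a sketch of the request/response contract; main.py below wraps the same call in a Gradio UI):

```python
import requests

# assumes fast_api.py is running locally on port 8000
r = requests.post("http://127.0.0.1:8000/trans", json={"query": "namaste aapko"})
print(r.json())  # {"response": [...]}: one transliterated word per input word
```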
inference/__init__.py ADDED
File without changes
inference/decoder (1).pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:66c2e34a0036672d568c9b5faa09bcd80f75f0edeea6053aeacb20b8094aace6
+ size 14137430
inference/demo.ipynb ADDED
@@ -0,0 +1,185 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import pickle\n",
+ "from language import Language\n",
+ "from utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward\n",
+ "import warnings\n",
+ "from typing import List\n",
+ "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['लॉक्सलाक्राक्यालालासी']"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "s = 'a' * 1\n",
+ "generate([s])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'^थ्रालाष्राप्टोार्फ्रास्रफ्फ्फ्'"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output_lang.decode(o.tolist()[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([20, 4, 5, 12, 4, 3])"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "s = \"pankaj\"\n",
+ "torch.tensor(input_lang.encode(s), device=device, dtype=torch.long)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running on local URL: http://127.0.0.1:7864\n",
+ "\n",
+ "Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import requests\n",
+ "import gradio as gr\n",
+ "\n",
+ "# Define the API endpoint\n",
+ "API_URL = \"http://127.0.0.1:8000/trans\"\n",
+ "\n",
+ "# Function to call the FastAPI backend\n",
+ "def predict(user_input):\n",
+ "    # Prepare the data to send to the FastAPI API\n",
+ "    payload = {\"query\": user_input}\n",
+ "\n",
+ "    # Make a request to the FastAPI backend\n",
+ "    response = requests.post(API_URL, json=payload)\n",
+ "\n",
+ "    # Get the response JSON\n",
+ "    result = response.json()\n",
+ "\n",
+ "    # Extract the answer\n",
+ "    return \" \".join(result[\"response\"])\n",
+ "\n",
+ "\n",
+ "# Launch the Gradio interface\n",
+ "if __name__ == \"__main__\":\n",
+ "    gr.Interface(predict,\n",
+ "                 inputs=['textbox'],\n",
+ "                 outputs=['text']).launch(share=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "def greet(name, intensity):\n",
+ "    return \"Hello, \" + name + \"!\" * int(intensity)\n",
+ "\n",
+ "demo = gr.Interface(\n",
+ "    fn=greet,\n",
+ "    inputs=[\"text\", \"slider\"],\n",
+ "    outputs=[\"text\"],\n",
+ ")\n",
+ "\n",
+ "demo.launch()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "transliteration",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
inference/encoder (1).pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b433380e98834313e590be330b6c2543eb8131ce043922fd95d37def151d1e7c
+ size 9799034
inference/input_lang.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9671af04c3790c0ae6183278806e948f2674db3506d3b1a1c65b924a4adbec78
+ size 396
inference/language.py ADDED
@@ -0,0 +1,25 @@
+ class Language:
+     def __init__(self, name):
+         self.name = name
+         self.char2index = {'#': 0, '$': 1, '^': 2}  # '^': start of sequence, '$': unknown char, '#': padding
+         self.index2char = {0: '#', 1: '$', 2: '^'}
+         self.vocab_size = 3  # count of the three special tokens
+
+     def addWord(self, word):
+         # register every character of the word in the vocabulary
+         for char in word:
+             self.addChar(char)
+
+     def addChar(self, char):
+         if char not in self.char2index:
+             self.char2index[char] = self.vocab_size
+             self.index2char[self.vocab_size] = char
+             self.vocab_size += 1
+
+     def encode(self, s):
+         # map a string to a list of character indices (raises KeyError on unseen chars)
+         return [self.char2index[ch] for ch in s]
+
+     def decode(self, l):
+         # map a list of indices back to a string
+         return ''.join([self.index2char[i] for i in l])
+
+     def vocab(self):
+         return self.char2index.keys()
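A quick round trip with this vocabulary class (the indices beyond the three special tokens depend on insertion order):

```python
from inference.language import Language

lang = Language("eng")
lang.addWord("pankaj")       # registers p, a, n, k, j
ids = lang.encode("pankaj")  # e.g. [3, 4, 5, 6, 4, 7]
print(lang.decode(ids))      # -> "pankaj"
```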
inference/output_lang.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:412857775d61629580d62df66b7e663bab247467f9a975920df3b1265acd6326
+ size 964
inference/transformer.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ import pickle
+ import sys
+ import os
+ from inference.language import Language
+ from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
+ import warnings
+ from typing import List
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ with open(os.path.join(os.path.dirname(__file__), 'input_lang.pkl'), "rb") as file:
+     input_lang = pickle.load(file)
+
+ with open(os.path.join(os.path.dirname(__file__), 'output_lang.pkl'), "rb") as file:
+     output_lang = pickle.load(file)
+
+ encoder = torch.load(os.path.join(os.path.dirname(__file__), 'encoder (1).pth'), map_location=device)
+ decoder = torch.load(os.path.join(os.path.dirname(__file__), 'decoder (1).pth'), map_location=device)
+ encoder.eval()  # disable dropout for inference
+ decoder.eval()
+
+ input_vocab_size = input_lang.vocab_size
+ output_vocab_size = output_lang.vocab_size
+
+ def encode(s):
+     # unseen characters are mapped to the unknown token '$'
+     return [input_lang.char2index.get(ch, input_lang.char2index['$']) for ch in s]
+
+ def generate(input: List[str]) -> List[str]:
+     # pre-process the input: pad or truncate every word to the same length, max_length = 33
+     for i, inp in enumerate(input):
+         input[i] = inp[:33] if len(inp) > 33 else inp.ljust(33, '#')
+
+     input = torch.tensor([encode(i) for i in input], device=device, dtype=torch.long)
+     B, T = input.shape
+
+     encoder_output = encoder(input)
+     idx = torch.full((B, 1), 2, dtype=torch.long, device=device)  # (B,1), start-of-sequence token '^'
+
+     # idx is a (B, T) array of indices in the current context
+     for _ in range(30):
+         # get the predictions
+         logits, _ = decoder(idx, encoder_output)  # logits (B, T, vocab_size)
+         # focus only on the last time step
+         logits = logits[:, -1, :]  # becomes (B, C)
+         # greedily pick the most likely next character
+         idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
+         # append the chosen index to the running sequence
+         idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+
+     ans = []
+     for seq in idx:
+         # drop the start token and cut at the first padding character
+         ans.append(output_lang.decode(seq.tolist()[1:]).split('#', 1)[0])
+     return ans
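A minimal usage sketch (the LFS model files above must be fetched first; the exact output depends on the trained weights):

```python
from inference.transformer import generate

print(generate(["namaste", "aapko"]))  # e.g. ['नमस्ते', 'आपको']
```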
inference/utility.py ADDED
@@ -0,0 +1,172 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ encoder_block_size = 33
+ decoder_block_size = 30
+
+ class Head(nn.Module):
+     """ one self-attention head """
+
+     def __init__(self, n_embd, d_k, dropout, mask=0):  # d_k is the dimension of the key; normally d_k = n_embd / n_head
+         super().__init__()
+         self.mask = mask
+         self.key = nn.Linear(n_embd, d_k, bias=False, device=device)
+         self.query = nn.Linear(n_embd, d_k, bias=False, device=device)
+         self.value = nn.Linear(n_embd, d_k, bias=False, device=device)
+         if mask:
+             self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, encoder_output=None):
+         B, T, C = x.shape
+
+         # for cross-attention, keys (and values) come from the encoder output
+         if encoder_output is not None:
+             k = self.key(encoder_output)
+             Be, Te, Ce = encoder_output.shape
+         else:
+             k = self.key(x)  # (B,T,d_k)
+
+         q = self.query(x)  # (B,T,d_k)
+         # compute attention scores (note: scaled by C = n_embd, matching how the checkpoints were trained)
+         wei = q @ k.transpose(-2, -1) * C**-0.5  # (B,T,T)
+
+         if self.mask:
+             if encoder_output is not None:
+                 wei = wei.masked_fill(self.tril[:T, :Te] == 0, float('-inf'))  # (B,T,Te)
+             else:
+                 wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T)
+
+         wei = F.softmax(wei, dim=-1)
+         wei = self.dropout(wei)
+         # perform weighted aggregation of values
+         if encoder_output is not None:
+             v = self.value(encoder_output)
+         else:
+             v = self.value(x)
+         out = wei @ v  # (B,T,d_k)
+         return out
+
+ class MultiHeadAttention(nn.Module):
+     """ multiple self-attention heads in parallel """
+
+     def __init__(self, n_embd, num_head, d_k, dropout, mask=0):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])
+         self.proj = nn.Linear(n_embd, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, encoder_output=None):
+         out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedForward(nn.Module):
+     """ position-wise feed-forward network """
+
+     def __init__(self, n_embd, dropout):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(n_embd, 4 * n_embd),
+             nn.ReLU(),
+             nn.Linear(4 * n_embd, n_embd),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class encoderBlock(nn.Module):
+     """ Transformer encoder block: communication followed by computation """
+
+     def __init__(self, n_embd, n_head, dropout):
+         super().__init__()
+         d_k = n_embd // n_head
+         self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout)
+         self.ffwd = FeedForward(n_embd, dropout)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x, encoder_output=None):
+         x = x + self.sa(self.ln1(x), encoder_output)
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+ class Encoder(nn.Module):
+
+     def __init__(self, n_embd, n_head, n_layers, dropout):
+         super().__init__()
+
+         # NOTE: input_vocab_size must exist at module level before constructing an Encoder;
+         # the shipped checkpoints are loaded whole via torch.load, so __init__ is not called there
+         self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd)  # n_embd: input embedding dimension
+         self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)
+         self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])
+         self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
+
+     def forward(self, idx):
+         B, T = idx.shape
+         tok_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,n_embd)
+         x = tok_emb + pos_emb  # (B,T,n_embd)
+         x = self.blocks(x)  # apply the stack of encoder blocks (B,T,C)
+         x = self.ln_f(x)  # (B,T,C)
+         return x
+
+
+ class decoderBlock(nn.Module):
+     """ Transformer decoder block: self communication, then cross communication, followed by computation """
+
+     def __init__(self, n_embd, n_head, dropout):
+         super().__init__()
+         d_k = n_embd // n_head
+         self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask=1)
+         self.ca = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask=1)
+         self.ffwd = FeedForward(n_embd, dropout)
+         self.ln1 = nn.LayerNorm(n_embd, device=device)
+         self.ln2 = nn.LayerNorm(n_embd, device=device)
+         self.ln3 = nn.LayerNorm(n_embd, device=device)
+
+     def forward(self, x_encoder_output):
+         # nn.Sequential passes a single argument, so x and the encoder output travel as a tuple
+         x = x_encoder_output[0]
+         encoder_output = x_encoder_output[1]
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ca(self.ln2(x), encoder_output)
+         x = x + self.ffwd(self.ln3(x))
+         return (x, encoder_output)
+
+ class Decoder(nn.Module):
+
+     def __init__(self, n_embd, n_head, n_layers, dropout):
+         super().__init__()
+
+         self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd)  # n_embd: output embedding dimension
+         self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)
+         self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])
+         self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
+         self.lm_head = nn.Linear(n_embd, output_vocab_size)
+
+     def forward(self, idx, encoder_output, targets=None):
+         B, T = idx.shape
+
+         tok_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
+         pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,n_embd)
+         x = tok_emb + pos_emb  # (B,T,n_embd)
+
+         x = self.blocks((x, encoder_output))
+         x = self.ln_f(x[0])  # (B,T,C)
+         logits = self.lm_head(x)  # (B,T,output_vocab_size)
+
+         if targets is None:
+             loss = None
+         else:
+             B, T, C = logits.shape
+             temp_logits = logits.view(B*T, C)
+             targets = targets.reshape(B*T)
+
+             loss = F.cross_entropy(temp_logits, targets.long())
+
+         # print(logits)
+         # out = torch.argmax(logits)
+
+         return logits, loss
+
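A quick shape check for one attention Head under assumed toy sizes (self-attention path, no masking):

```python
import torch
from inference.utility import Head, device

h = Head(n_embd=32, d_k=8, dropout=0.0)
x = torch.randn(2, 10, 32, device=device)  # (B, T, C) toy batch
print(h(x).shape)                          # torch.Size([2, 10, 8])
```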
main.py ADDED
@@ -0,0 +1,26 @@
+ import requests
+ import gradio as gr
+
+ # Define the API endpoint
+ API_URL = "http://127.0.0.1:8000/trans"
+
+ # Function to call the FastAPI backend
+ def predict(user_input):
+     # Prepare the data to send to the FastAPI API
+     payload = {"query": user_input}
+
+     # Make a request to the FastAPI backend
+     response = requests.post(API_URL, json=payload)
+
+     # Get the response JSON
+     result = response.json()
+
+     # Extract the answer
+     return " ".join(result["response"])
+
+
+ # Launch the Gradio interface
+ if __name__ == "__main__":
+     gr.Interface(predict,
+                  inputs=['textbox'],
+                  outputs=['text']).launch(share=True)
notebooks/encoder_decoder_RNNs.ipynb ADDED
@@ -0,0 +1,1924 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/pankajrawat9075/CS6910_assignment_3/blob/main/DL_PA3_final.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "hRdpoWePeYHn"
17
+ },
18
+ "source": [
19
+ "## Importing Libraries and models"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {
26
+ "id": "0LBvFtYGCNgJ"
27
+ },
28
+ "outputs": [],
29
+ "source": [
30
+ "%%capture\n",
31
+ "!pip install wandb"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {
38
+ "id": "zkZTzr7OCPBM"
39
+ },
40
+ "outputs": [],
41
+ "source": [
42
+ "import wandb"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {
49
+ "id": "z4ZVrIumZcDt"
50
+ },
51
+ "outputs": [],
52
+ "source": [
53
+ "from __future__ import unicode_literals, print_function, division\n",
54
+ "from io import open\n",
55
+ "import unicodedata\n",
56
+ "import string\n",
57
+ "import re\n",
58
+ "import random\n",
59
+ "import pandas as pd\n",
60
+ "import torch\n",
61
+ "import torch.nn as nn\n",
62
+ "from torch import optim\n",
63
+ "import torch.nn.functional as F\n",
64
+ "\n",
65
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
66
+ "torch.cuda.empty_cache()"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {
73
+ "colab": {
74
+ "base_uri": "https://localhost:8080/"
75
+ },
76
+ "id": "qwL09v65CIse",
77
+ "outputId": "f1dcbc80-5110-48f9-d0c5-836a2daa05b4"
78
+ },
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "cuda\n"
85
+ ]
86
+ }
87
+ ],
88
+ "source": [
89
+ "print(device)"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {
95
+ "id": "44xIRolL_T_d"
96
+ },
97
+ "source": [
98
+ "## Load Dataset"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {
105
+ "colab": {
106
+ "base_uri": "https://localhost:8080/"
107
+ },
108
+ "id": "-XRMpx9eBzRK",
109
+ "outputId": "177ee7ae-bb7d-46ea-9269-fa3aa045a89e"
110
+ },
111
+ "outputs": [
112
+ {
113
+ "name": "stdout",
114
+ "output_type": "stream",
115
+ "text": [
116
+ "Mounted at /content/drive\n"
117
+ ]
118
+ }
119
+ ],
120
+ "source": [
121
+ "from google.colab import drive\n",
122
+ "drive.mount('/content/drive')"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {
129
+ "id": "Y4zemXiyE6Fi"
130
+ },
131
+ "outputs": [],
132
+ "source": [
133
+ "class Lang:\n",
134
+ " def __init__(self, name):\n",
135
+ " self.name = name\n",
136
+ " self.char2index = {'#': 0, '$': 1, '^': 2}\n",
137
+ " self.char2count = {'#': 1, '$': 1, '^': 1}\n",
138
+ " self.index2char = {0: '#', 1: '$', 2: '^'}\n",
139
+ " self.n_chars = 3 # Count\n",
140
+ " self.data = {}\n",
141
+ " \n",
142
+ "\n",
143
+ " def addWord(self, word):\n",
144
+ " for char in word:\n",
145
+ " self.addChar(char)\n",
146
+ "\n",
147
+ " def addChar(self, char):\n",
148
+ " if char not in self.char2index:\n",
149
+ " self.char2index[char] = self.n_chars\n",
150
+ " self.char2count[char] = 1\n",
151
+ " self.index2char[self.n_chars] = char\n",
152
+ " self.n_chars += 1\n",
153
+ " else:\n",
154
+ " self.char2count[char] += 1\n",
155
+ "\n",
156
+ " \n"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {
163
+ "id": "dCR658yRvXpy"
164
+ },
165
+ "outputs": [],
166
+ "source": [
167
+ "# return max length of input and output words\n",
168
+ "def maxLength(data):\n",
169
+ " ip_mlen, op_mlen = 0, 0\n",
170
+ "\n",
171
+ " for i in range(len(data)):\n",
172
+ " input = data[0][i]\n",
173
+ " output = data[1][i]\n",
174
+ " if(len(input)>ip_mlen):\n",
175
+ " ip_mlen=len(input)\n",
176
+ "\n",
177
+ " if(len(output)>op_mlen):\n",
178
+ " op_mlen=len(output)\n",
179
+ "\n",
180
+ " return ip_mlen, op_mlen"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "metadata": {
187
+ "id": "IDGaCO8DkYpc"
188
+ },
189
+ "outputs": [],
190
+ "source": [
191
+ "import numpy\n",
192
+ "input_shape = 0\n",
193
+ "from torch.utils.data import TensorDataset, DataLoader\n",
194
+ "def preprocess(data, input_lang, output_lang):\n",
195
+ " maxlenInput, maxlenOutput = maxLength(data)\n",
196
+ " # we use maxlenInput as 26 since it is the maximum of all input len\n",
197
+ " maxlenInput = 26\n",
198
+ " input = numpy.zeros((len(data), maxlenInput + 1))\n",
199
+ " output = numpy.zeros((len(data), maxlenOutput + 2))\n",
200
+ " maxlenInput, maxlenOutput = maxLength(data)\n",
201
+ " unknown = input_lang.char2index['$']\n",
202
+ "\n",
203
+ " for i in range(len(data)):\n",
204
+ " op = '^' + data[1][i]\n",
205
+ " ip = data[0][i].ljust(maxlenInput + 1, '#')\n",
206
+ " op = op.ljust(maxlenOutput + 2, '#')\n",
207
+ " \n",
208
+ "\n",
209
+ " for index, char in enumerate(ip):\n",
210
+ " if input_lang.char2index.get(char) is not None:\n",
211
+ " input[i][index] = input_lang.char2index[char]\n",
212
+ " else:\n",
213
+ " input[i][index] = unknown\n",
214
+ " \n",
215
+ "\n",
216
+ " \n",
217
+ " for index, char in enumerate(op):\n",
218
+ " if output_lang.char2index.get(char) is not None:\n",
219
+ " output[i][index] = output_lang.char2index[char]\n",
220
+ " else:\n",
221
+ " output[i][index] = unknown \n",
222
+ "\n",
223
+ " print(input.shape)\n",
224
+ " print(output.shape)\n",
225
+ "\n",
226
+ " return TensorDataset(torch.from_numpy(input), torch.from_numpy(output))"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {
233
+ "colab": {
234
+ "base_uri": "https://localhost:8080/"
235
+ },
236
+ "id": "PdS5OXKxfdCX",
237
+ "outputId": "178f1d73-5b0c-431d-ca9b-d9435b924c41"
238
+ },
239
+ "outputs": [
240
+ {
241
+ "name": "stdout",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "(51200, 27)\n",
245
+ "(51200, 22)\n",
246
+ "(4096, 27)\n",
247
+ "(4096, 22)\n",
248
+ "(4096, 27)\n",
249
+ "(4096, 22)\n"
250
+ ]
251
+ }
252
+ ],
253
+ "source": [
254
+ "def loadData(lang):\n",
255
+ " train_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_train.csv\", header = None)\n",
256
+ " val_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_valid.csv\", header = None)\n",
257
+ " test_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_test.csv\", header = None)\n",
258
+ "\n",
259
+ " input_lang = Lang('eng')\n",
260
+ " output_lang = Lang(lang)\n",
261
+ " \n",
262
+ " # add the words to the respective languages\n",
263
+ " for i in range(len(train_df)):\n",
264
+ " \n",
265
+ " input_lang.addWord(train_df[0][i])\n",
266
+ " output_lang.addWord(train_df[1][i])\n",
267
+ "\n",
268
+ " # print(input_lang.char2index)\n",
269
+ " # print(input_lang.index2char)\n",
270
+ " trainDataset = preprocess(train_df, input_lang, output_lang)\n",
271
+ " testDataset = preprocess(test_df, input_lang, output_lang)\n",
272
+ " valDataset = preprocess(val_df, input_lang, output_lang)\n",
273
+ "\n",
274
+ " return trainDataset, testDataset, valDataset, input_lang, output_lang\n",
275
+ "\n",
276
+ "\n",
277
+ "trainData, testData, valData, ipLang, opLang = loadData('hin')\n"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": null,
283
+ "metadata": {
284
+ "colab": {
285
+ "base_uri": "https://localhost:8080/"
286
+ },
287
+ "id": "SvmzS5Lt_Jnl",
288
+ "outputId": "33defb60-5aee-46cb-e683-ee2df9e98436"
289
+ },
290
+ "outputs": [
291
+ {
292
+ "name": "stderr",
293
+ "output_type": "stream",
294
+ "text": [
295
+ "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
296
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
297
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
298
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
299
+ ]
300
+ },
301
+ {
302
+ "data": {
303
+ "text/plain": [
304
+ "True"
305
+ ]
306
+ },
307
+ "execution_count": 10,
308
+ "metadata": {},
309
+ "output_type": "execute_result"
310
+ }
311
+ ],
312
+ "source": [
313
+ "wandb.login(key =\"\")"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "markdown",
318
+ "metadata": {
319
+ "id": "Q1TioafYgICa"
320
+ },
321
+ "source": [
322
+ "# seq2seq model"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {
328
+ "id": "svxssm9Havhb"
329
+ },
330
+ "source": [
331
+ "## Encoder"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {
338
+ "id": "YTwk8nKNcbkb"
339
+ },
340
+ "outputs": [],
341
+ "source": [
342
+ "class EncoderRNN(nn.Module):\n",
343
+ " def __init__(self, input_size, hidden_size, embedding_size, # input_size is size of input language dictionary\n",
344
+ " num_layers, cell_type,\n",
345
+ " bidirectional, dropout, batch_size) :\n",
346
+ " super(EncoderRNN, self).__init__()\n",
347
+ " self.hidden_size = hidden_size # size of an hidden state representation\n",
348
+ " self.num_layers = num_layers \n",
349
+ " self.bidirectional = True if bidirectional == 'Yes' else False\n",
350
+ " self.batch_size = batch_size\n",
351
+ " self.cell_type = cell_type\n",
352
+ " self.embedding_size=embedding_size\n",
353
+ "\n",
354
+ " # this adds the embedding layer\n",
355
+ " self.embedding = nn.Embedding(num_embeddings=input_size,embedding_dim= embedding_size)\n",
356
+ " self.dropout = nn.Dropout(dropout)\n",
357
+ "\n",
358
+ " # this adds the Neural Network layer for the encoder\n",
359
+ " if self.cell_type == \"GRU\":\n",
360
+ " self.rnn = nn.GRU(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, dropout=dropout)\n",
361
+ " elif self.cell_type == \"LSTM\":\n",
362
+ " self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, dropout=dropout)\n",
363
+ " else:\n",
364
+ " self.rnn = nn.RNN(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, dropout=dropout)\n",
365
+ "\n",
366
+ " def forward(self, input, hidden): # input shape (seq_len, batch_size) hidden shape tuple for lstm, otherwise single\n",
367
+ " embedded = self.embedding(input.long()).view(-1,self.batch_size, self.embedding_size)\n",
368
+ " output = self.dropout(embedded) # output shape (seq_len, batch_size, embedding size)\n",
369
+ "\n",
370
+ " output, hidden = self.rnn(output, hidden) # for LSTM hidden is a tuple\n",
371
+ " if self.bidirectional:\n",
372
+ " if self.cell_type == \"LSTM\":\n",
373
+ " hidden_state = hidden[0].resize(2,self.num_layers,self.batch_size,self.hidden_size)\n",
374
+ " cell_state = hidden[1].resize(2,self.num_layers,self.batch_size,self.hidden_size)\n",
375
+ " hidden = (torch.add(hidden_state[0],hidden_state[1])/2, torch.add(cell_state[0],cell_state[1])/2)\n",
376
+ " else:\n",
377
+ " hidden=hidden.resize(2,self.num_layers,self.batch_size,self.hidden_size)\n",
378
+ " hidden=torch.add(hidden[0],hidden[1])/2\n",
379
+ " \n",
380
+ " split_tensor= torch.split(output, self.hidden_size, dim=-1)\n",
381
+ " output=torch.add(split_tensor[0],split_tensor[1])/2\n",
382
+ " return output, hidden\n",
383
+ "\n",
384
+ " # initializing the initial hidden state for the encoder\n",
385
+ " def initHidden(self):\n",
386
+ " num_directions = 2 if self.bidirectional else 1\n",
387
+ " if self.cell_type == \"LSTM\":\n",
388
+ " return (torch.zeros(self.num_layers * num_directions, self.batch_size, self.hidden_size, device=device),\n",
389
+ " torch.zeros(self.num_layers * num_directions, self.batch_size, self.hidden_size, device=device))\n",
390
+ " else:\n",
391
+ " return torch.zeros(self.num_layers * num_directions, self.batch_size, self.hidden_size, device=device)\n"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "markdown",
396
+ "metadata": {
397
+ "id": "J56aq1J6a07q"
398
+ },
399
+ "source": [
400
+ "## Decoder"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "metadata": {
407
+ "id": "53ki6eJUH2u2"
408
+ },
409
+ "outputs": [],
410
+ "source": [
411
+ "class DecoderRNN(nn.Module):\n",
412
+ " def __init__(self, hidden_size, output_size, embedding_size, num_layers, # output size is the size of output language dictionary\n",
413
+ " cell_type, dropout, batch_size):\n",
414
+ " super(DecoderRNN, self).__init__()\n",
415
+ " self.hidden_size = hidden_size\n",
416
+ " self.num_layers = num_layers\n",
417
+ " self.cell_type = cell_type.lower()\n",
418
+ " self.batch_size = batch_size\n",
419
+ " self.embedding_size=embedding_size\n",
420
+ "\n",
421
+ " self.embedding = nn.Embedding(output_size, embedding_size)\n",
422
+ " # self.dropout = nn.Dropout(dropout)\n",
423
+ " \n",
424
+ " if self.cell_type == \"gru\":\n",
425
+ " self.rnn = nn.GRU(embedding_size, hidden_size, num_layers=num_layers)\n",
426
+ " elif self.cell_type == \"lstm\":\n",
427
+ " self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers)\n",
428
+ " else:\n",
429
+ " self.rnn = nn.RNN(embedding_size, hidden_size, num_layers=num_layers)\n",
430
+ "\n",
431
+ " self.out = nn.Linear(hidden_size, output_size)\n",
432
+ " self.softmax = nn.LogSoftmax(dim=2)\n",
433
+ "\n",
434
+ " def forward(self, input, hidden): # input shape (1, batch_size)\n",
435
+ " embedded = self.embedding(input.long()).view(-1, self.batch_size, self.embedding_size)\n",
436
+ " # # shape (1, batch_size, embedding_size)\n",
437
+ " output = F.relu(embedded)\n",
438
+ " output, hidden = self.rnn(output, hidden) # output shape (1, batch_size, hidden_size)\n",
439
+ " output = self.softmax(self.out(output)) # shape (1, batch_size, output_size)\n",
440
+ " return output, hidden\n",
441
+ "\n",
442
+ " # not needed since hidden will be provided by the encoder"
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "markdown",
447
+ "metadata": {
448
+ "id": "5JcQdylzI_Fc"
449
+ },
450
+ "source": [
451
+ "## Attention Decoder"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": null,
457
+ "metadata": {
458
+ "id": "R1Xysuv9I-Qr"
459
+ },
460
+ "outputs": [],
461
+ "source": [
462
+ "class AttentionDecoderRNN(nn.Module):\n",
463
+ " def __init__(self, hidden_size, output_size, embedding_size, num_layers,\n",
464
+ " cell_type, dropout, batch_size, max_length):\n",
465
+ " super(AttentionDecoderRNN, self).__init__()\n",
466
+ " self.hidden_size = hidden_size\n",
467
+ " self.num_layers = num_layers\n",
468
+ " self.cell_type = cell_type\n",
469
+ " self.batch_size = batch_size\n",
470
+ " self.embedding_size = embedding_size\n",
471
+ " self.max_length = max_length\n",
472
+ " self.dropout = dropout\n",
473
+ "\n",
474
+ " self.embedding = nn.Embedding(output_size, embedding_size)\n",
475
+ " self.dropout = nn.Dropout(self.dropout)\n",
476
+ " self.attention = nn.Linear(hidden_size + embedding_size, self.max_length)\n",
477
+ " self.attention_combine = nn.Linear(hidden_size + embedding_size, hidden_size)\n",
478
+ "\n",
479
+ " if self.cell_type == \"GRU\":\n",
480
+ " self.rnn = nn.GRU(hidden_size, hidden_size, num_layers=num_layers)\n",
481
+ " elif self.cell_type == \"LSTM\":\n",
482
+ " self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)\n",
483
+ " else:\n",
484
+ " self.rnn = nn.RNN(hidden_size, hidden_size, num_layers=num_layers)\n",
485
+ "\n",
486
+ " self.out = nn.Linear(hidden_size, output_size)\n",
487
+ " self.softmax = nn.LogSoftmax(dim=2)\n",
488
+ "\n",
489
+ " def forward(self, input, hidden, encoder_outputs): #input shape (1, batch_size)\n",
490
+ " embedded = self.embedding(input.long()).view(-1, self.batch_size, self.embedding_size) \n",
491
+ " # embedded shape (1, batch_size, embedding_size)\n",
492
+ " embedded = F.relu(embedded)\n",
493
+ "\n",
494
+ " # Compute attention scores\n",
495
+ " if self.cell_type == \"LSTM\":\n",
496
+ " attn_hidden = torch.mean(hidden[0], dim=0)\n",
497
+ " else:\n",
498
+ " attn_hidden = torch.mean(hidden, dim = 0)\n",
499
+ " attn_scores = self.attention(torch.cat((embedded, attn_hidden.unsqueeze(0)), dim=2)) # attn_scores shape (1, batch_size, max_length)\n",
500
+ " \n",
501
+ " attn_weights = F.softmax(attn_scores, dim=-1) # attn_scores shape (1, 16, 25)\n",
502
+ " \n",
503
+ "\n",
504
+ " # Apply attention weights to encoder outputs\n",
505
+ " attn_applied = torch.bmm(attn_weights.transpose(0, 1), encoder_outputs.transpose(0, 1))\n",
506
+ " \n",
507
+ " # Combine attention output and embedded input\n",
508
+ " combined = torch.cat((embedded, attn_applied.transpose(0, 1)), dim=2)\n",
509
+ " combined = self.attention_combine(combined)\n",
510
+ " combined = F.relu(combined) # shape (1, batch_size, hidden_size)\n",
511
+ "\n",
512
+ " # Run through the RNN\n",
513
+ " output, hidden = self.rnn(combined, hidden)\n",
514
+ " # output shape: (1, batch_size, hidden_size)\n",
515
+ "\n",
516
+ " # Pass through linear layer and softmax activation\n",
517
+ " output = self.out(output) # shape: (1, batch_size, output_size)\n",
518
+ " output = self.softmax(output)\n",
519
+ " return output, hidden, attn_weights.transpose(0, 1)\n"
520
+ ]
521
+ },
522
+ {
523
+ "cell_type": "code",
524
+ "execution_count": null,
525
+ "metadata": {
526
+ "id": "LJ2Papj_jTX8"
527
+ },
528
+ "outputs": [],
529
+ "source": []
530
+ },
531
+ {
532
+ "cell_type": "markdown",
533
+ "metadata": {
534
+ "id": "658W9RARGEUf"
535
+ },
536
+ "source": [
537
+ "# Helper functions"
538
+ ]
539
+ },
540
+ {
541
+ "cell_type": "markdown",
542
+ "metadata": {
543
+ "id": "q7fAgs5uQni_"
544
+ },
545
+ "source": [
546
+ "## count matches"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": null,
552
+ "metadata": {
553
+ "id": "8fzy8U6_lbug"
554
+ },
555
+ "outputs": [],
556
+ "source": [
557
+ "def count_exact_matches(pred, target):\n",
558
+ " \"\"\"\n",
559
+ " Counts the number of rows in preds tensor that match exactly with each row in y tensor.\n",
560
+ " pred: tensor of shape (batch_size, seq_len-1)\n",
561
+ " y: tensor of shape (batch_size, seq_len-1)\n",
562
+ " \"\"\"\n",
563
+ " \n",
564
+ " count=0;\n",
565
+ " for i in range(pred.shape[0]):\n",
566
+ " flag = True\n",
567
+ " for j in range(pred.shape[1]):\n",
568
+ " if(target[i][j]!=pred[i][j]):\n",
569
+ " flag=False\n",
570
+ " break;\n",
571
+ " \n",
572
+ " if(flag):\n",
573
+ " count+=1;\n",
574
+ " \n",
575
+ " return count"
576
+ ]
577
+ },
578
+ {
579
+ "cell_type": "markdown",
580
+ "metadata": {
581
+ "id": "n4rGh7vuQqaa"
582
+ },
583
+ "source": [
584
+ "## evaluation"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": null,
590
+ "metadata": {
591
+ "id": "zp6gvWmDlWoB"
592
+ },
593
+ "outputs": [],
594
+ "source": [
595
+ "def evaluate(data,encoder, decoder,output_size,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention):\n",
596
+ " \n",
597
+ "\n",
598
+ "\n",
599
+ " running_loss = 0\n",
600
+ " correct =0\n",
601
+ " \n",
602
+ " loader = DataLoader(data, batch_size=batch_size)\n",
603
+ " loss_fun = nn.CrossEntropyLoss(reduction=\"sum\")\n",
604
+ " seq_len = 0\n",
605
+ "\n",
606
+ " atten_weights = torch.zeros(1,21, 27).to(device) # required to return the attention weights\n",
607
+ " predictions = torch.zeros(22-1, 1).to(device)\n",
608
+ " with torch.no_grad():\n",
609
+ " for j,(x,y) in enumerate(loader):\n",
610
+ " loss=0\n",
611
+ " encoder.eval()\n",
612
+ " decoder.eval()\n",
613
+ "\n",
614
+ " x = x.to(device)\n",
615
+ " y = y.to(device)\n",
616
+ "\n",
617
+ " x = x.T\n",
618
+ " y = y.T\n",
619
+ " seq_len = len(y)\n",
620
+ " \n",
621
+ " encoder_hidden=encoder.initHidden()\n",
622
+ " encoder_output,encoder_hidden = encoder(x,encoder_hidden)\n",
623
+ " \n",
624
+ " \n",
625
+ " decoder_input =y[0]\n",
626
+ " \n",
627
+ " # Handle different numbers of layers in the encoder and decoder\n",
628
+ " if num_layers_encoder != num_layers_decoder:\n",
629
+ " if num_layers_encoder < num_layers_decoder:\n",
630
+ " remaining_layers = num_layers_decoder - num_layers_encoder\n",
631
+ "\n",
632
+ " # Copy all encoder hidden layers and then repeat the top layer\n",
633
+ " if cell_type == \"LSTM\":\n",
634
+ " top_layer_hidden = (encoder_hidden[0][-1].unsqueeze(0), encoder_hidden[1][-1].unsqueeze(0))\n",
635
+ " extra_hidden = (top_layer_hidden[0].repeat(remaining_layers, 1, 1), top_layer_hidden[1].repeat(remaining_layers, 1, 1))\n",
636
+ " decoder_hidden = (torch.cat((encoder_hidden[0], extra_hidden[0]), dim=0), torch.cat((encoder_hidden[1], extra_hidden[1]), dim=0))\n",
637
+ " else:\n",
638
+ " top_layer_hidden = encoder_hidden[-1].unsqueeze(0) #top_layer_hidden shape (1, batch_size, hidden_size)\n",
639
+ " extra_hidden = top_layer_hidden.repeat(remaining_layers, 1, 1)\n",
640
+ " decoder_hidden = torch.cat((encoder_hidden, extra_hidden), dim=0)\n",
641
+ "\n",
642
+ " else:\n",
643
+ " # Slice the hidden states of the encoder to match the decoder layers\n",
644
+ " if cell_type == \"LSTM\":\n",
645
+ " decoder_hidden = (encoder_hidden[0][-num_layers_decoder:], encoder_hidden[1][-num_layers_decoder:])\n",
646
+ " else :\n",
647
+ " decoder_hidden = encoder_hidden[-num_layers_decoder:]\n",
648
+ " else:\n",
649
+ " decoder_hidden = encoder_hidden\n",
650
+ "\n",
651
+ " pred=torch.zeros(len(y)-1, batch_size).to(device)\n",
652
+ " atten_weight_default = torch.zeros(batch_size,1, 27).to(device)\n",
653
+ " for k in range(1,len(y)):\n",
654
+ " if attention == \"Yes\":\n",
655
+ " \n",
656
+ " decoder_output, decoder_hidden, atten_weight = decoder(decoder_input, decoder_hidden, encoder_output)\n",
657
+ " atten_weight_default = torch.cat((atten_weight_default, atten_weight), dim = 1)\n",
658
+ " else:\n",
659
+ " decoder_output, decoder_hidden= decoder(decoder_input, decoder_hidden)\n",
660
+ " max_prob, index = decoder_output.topk(1) # max_prob shape (1, batch_size, 1)\n",
661
+ " decoder_output = torch.squeeze(decoder_output)\n",
662
+ " loss += loss_fun(decoder_output, y[k].long())\n",
663
+ " pred[k-1]= torch.squeeze(index)\n",
664
+ " decoder_input = index\n",
665
+ " if attention == \"Yes\":\n",
666
+ " atten_weights = torch.cat((atten_weights, atten_weight_default[:, 1:, :]), dim = 0)\n",
667
+ "\n",
668
+ " running_loss += loss.item()\n",
669
+ " correct += count_exact_matches(pred.T,y[1:,:].T)\n",
670
+ " predictions = torch.cat((predictions, pred), dim=1)\n",
671
+ "\n",
672
+ " \n",
673
+ " avg_loss = running_loss / (len(data) * seq_len)\n",
674
+ " print(\"correct =\", correct)\n",
675
+ " avg_acc = 100 * (correct / (len(data)))\n",
676
+ " if attention == \"Yes\":\n",
677
+ " return avg_loss, avg_acc, predictions, atten_weights[1:, :, :]\n",
678
+ " else:\n",
679
+ " return avg_loss, avg_acc, predictions\n",
680
+ " \n",
681
+ " \n",
682
+ " "
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "markdown",
687
+ "metadata": {
688
+ "id": "0SsnRWlgQmCI"
689
+ },
690
+ "source": [
691
+ "# Training function"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": null,
697
+ "metadata": {
698
+ "id": "PhDgsZG0QqPW"
699
+ },
700
+ "outputs": [],
701
+ "source": [
702
+ "def train(sweeps = True, test = False):\n",
703
+ "\n",
704
+ " if sweeps == False: \n",
705
+ " configs = config_defaults # use the default configuration which has the best hyperparameters\n",
706
+ " else:\n",
707
+ " wandb.init(config= config_defaults, project='DL_assign_3') # if not test then run wandb sweeps\n",
708
+ " configs=wandb.config\n",
709
+ " \n",
710
+ "\n",
711
+ " learn_rate = configs['learn_rate']\n",
712
+ " batch_size = configs['batch_size']\n",
713
+ " hidden_size = configs['hidden_size']\n",
714
+ " embedding_size = configs['embedding_size']\n",
715
+ " num_layers_encoder = configs['num_layers_encoder']\n",
716
+ " num_layers_decoder = configs['num_layers_decoder']\n",
717
+ " cell_type = configs['cell_type']\n",
718
+ " bidirectional = configs['bidirectional']\n",
719
+ " dropout = configs['dropout']\n",
720
+ " teach_ratio = configs['teach_ratio']\n",
721
+ " epochs = configs['epochs']\n",
722
+ " attention = configs['attention']\n",
723
+ "\n",
724
+ " if sweeps:\n",
725
+ " wandb.run.name='hidden_'+str(hidden_size)+'_batch_'+str(batch_size)+'_embed_size_'+str(embedding_size)+'_dropout_'+str(dropout)+'_cell_'+str(cell_type)\n",
726
+ "\n",
727
+ " input_len = ipLang.n_chars\n",
728
+ " output_len = opLang.n_chars\n",
729
+ " \n",
730
+ " encoder = EncoderRNN(input_len, hidden_size, embedding_size, \n",
731
+ " num_layers_encoder, cell_type,\n",
732
+ " bidirectional, dropout, batch_size)\n",
733
+ " \n",
734
+ " if attention ==\"Yes\":\n",
735
+ " decoder = AttentionDecoderRNN(hidden_size, output_len, embedding_size, num_layers_decoder, \n",
736
+ " cell_type, dropout, batch_size, 27)\n",
737
+ " else:\n",
738
+ " decoder = DecoderRNN(hidden_size, output_len, embedding_size, num_layers_decoder, \n",
739
+ " cell_type, dropout, batch_size)#dropout not used\n",
740
+ " \n",
741
+ " train_loader = DataLoader(trainData, batch_size=batch_size, shuffle=True)\n",
742
+ " val_loader = DataLoader(valData, batch_size=batch_size, shuffle=True)\n",
743
+ "\n",
744
+ " encoder_optimizer=optim.Adam(encoder.parameters(),learn_rate)\n",
745
+ " decoder_optimizer=optim.Adam(decoder.parameters(),learn_rate)\n",
746
+ " loss_fun=nn.CrossEntropyLoss(reduction=\"sum\")\n",
747
+ "\n",
748
+ " encoder.to(device)\n",
749
+ " decoder.to(device)\n",
750
+ " seq_len = 0\n",
751
+ "\n",
752
+ " # Initialize variables for early stopping\n",
753
+ " best_val_loss = float('inf')\n",
754
+ " patience = 5\n",
755
+ " epochs_without_improvement = 0\n",
756
+ "\n",
757
+ " for i in range(epochs):\n",
758
+ " \n",
759
+ " running_loss = 0.0\n",
760
+ " train_correct = 0\n",
761
+ "\n",
762
+ " encoder.train()\n",
763
+ " decoder.train()\n",
764
+ "\n",
765
+ " for j,(train_x,train_y) in enumerate(train_loader):\n",
766
+ " train_x = train_x.to(device)\n",
767
+ " train_y = train_y.to(device)\n",
768
+ "\n",
769
+ " encoder_optimizer.zero_grad()\n",
770
+ " decoder_optimizer.zero_grad()\n",
771
+ "\n",
772
+ " train_x=train_x.T\n",
773
+ " train_y=train_y.T\n",
774
+ " # print(\"train_x.shapetrain_x.shape)\n",
775
+ " seq_len = len(train_y)\n",
776
+ " encoder_hidden=encoder.initHidden()\n",
777
+ " # for LSTM encoder_hidden shape ((num_layers * num_directions, batch_size,hidden_size),(self.num_layers * num_directions, batch_size, hidden_size))\n",
778
+ " encoder_output,encoder_hidden = encoder(train_x,encoder_hidden)\n",
779
+ " # encoder_hidden shape (num_layers, batch_size, hidden_size)\n",
780
+ " \n",
781
+ " \n",
782
+ " # lets move to the decoder\n",
783
+ " decoder_input = train_y[0] # shape (1, batch_size)\n",
784
+ " \n",
785
+ " # Handle different numbers of layers in the encoder and decoder\n",
786
+ " if num_layers_encoder != num_layers_decoder:\n",
787
+ " if num_layers_encoder < num_layers_decoder:\n",
788
+ " remaining_layers = num_layers_decoder - num_layers_encoder\n",
789
+ " # Copy all encoder hidden layers and then repeat the top layer\n",
790
+ " if cell_type == \"LSTM\":\n",
791
+ " top_layer_hidden = (encoder_hidden[0][-1].unsqueeze(0), encoder_hidden[1][-1].unsqueeze(0))\n",
792
+ " extra_hidden = (top_layer_hidden[0].repeat(remaining_layers, 1, 1), top_layer_hidden[1].repeat(remaining_layers, 1, 1))\n",
793
+ " decoder_hidden = (torch.cat((encoder_hidden[0], extra_hidden[0]), dim=0), torch.cat((encoder_hidden[1], extra_hidden[1]), dim=0))\n",
794
+ " else:\n",
795
+ " top_layer_hidden = encoder_hidden[-1].unsqueeze(0) #top_layer_hidden shape (1, batch_size, hidden_size)\n",
796
+ " extra_hidden = top_layer_hidden.repeat(remaining_layers, 1, 1)\n",
797
+ " decoder_hidden = torch.cat((encoder_hidden, extra_hidden), dim=0)\n",
798
+ " \n",
799
+ " else:\n",
800
+ " # Slice the hidden states of the encoder to match the decoder layers\n",
801
+ " if cell_type == \"LSTM\":\n",
802
+ " decoder_hidden = (encoder_hidden[0][-num_layers_decoder:], encoder_hidden[1][-num_layers_decoder:])\n",
803
+ " else :\n",
804
+ " decoder_hidden = encoder_hidden[-num_layers_decoder:]\n",
805
+ " else:\n",
806
+ " decoder_hidden = encoder_hidden\n",
807
+ " \n",
808
+ " loss = 0\n",
809
+ " correct = 0\n",
810
+ " \n",
811
+ " for k in range(0, len(train_y)-1):\n",
812
+ " \n",
813
+ " if attention == \"Yes\":\n",
814
+ " decoder_output, decoder_hidden, atten_weights = decoder(decoder_input, decoder_hidden, encoder_output)\n",
815
+ " else:\n",
816
+ " decoder_output, decoder_hidden= decoder(decoder_input, decoder_hidden) # decoder_output shape (1, batch_size, output_size)\n",
817
+ "\n",
818
+ " max_prob, index = decoder_output.topk(1) # max_prob shape (1, batch_size, 1)\n",
819
+ " index = torch.squeeze(index) # shape (batch_size)\n",
820
+ " decoder_output = torch.squeeze(decoder_output)\n",
821
+ " loss += loss_fun(decoder_output, train_y[k+1].long())\n",
822
+ " \n",
823
+ " correct += (index == train_y[k+1]).sum().item()\n",
824
+ "\n",
825
+ " # Apply teacher forcing\n",
826
+ " use_teacher_forcing = True if random.random() < teach_ratio else False\n",
827
+ "\n",
828
+ " if use_teacher_forcing:\n",
829
+ " decoder_input = train_y[k+1]\n",
830
+ " \n",
831
+ " else:\n",
832
+ " decoder_input = index\n",
833
+ "\n",
834
+ " running_loss += loss.item()\n",
835
+ " train_correct += correct\n",
836
+ " loss.backward()\n",
837
+ " encoder_optimizer.step()\n",
838
+ " decoder_optimizer.step()\n",
839
+ " \n",
840
+ "\n",
841
+ " # find train loss and accuracy and print + log to wandb\n",
842
+ " if attention == \"Yes\":\n",
843
+ " _, train_accuracy,_, _ = evaluate(trainData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
844
+ " else:\n",
845
+ " _, train_accuracy,_= evaluate(trainData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
846
+ " \n",
847
+ " print(f\"epoch {i}, training loss {running_loss/(len(trainData)* seq_len)}, training accuracy {train_accuracy}\")\n",
848
+ " if sweeps:\n",
849
+ " wandb.log({\"epoch\": i, \"train_loss\": running_loss/(len(trainData)* seq_len), \"train_accuracy\": train_accuracy})\n",
850
+ " \n",
851
+ " # # find validation loss and accuracy and print + log to wandb\n",
852
+ " if attention == \"Yes\":\n",
853
+ " val_loss, val_accuracy,_, _ = evaluate(valData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
854
+ " else:\n",
855
+ " val_loss, val_accuracy,_ = evaluate(valData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
856
+ " \n",
857
+ " print(f\"epoch {i}, validation loss {val_loss}, validation accuracy {val_accuracy}\")\n",
858
+ " if sweeps:\n",
859
+ " wandb.log({\"val_loss\": val_loss, \"val_accuracy\": val_accuracy})\n",
860
+ "\n",
861
+ " # Check for early stopping\n",
862
+ " if val_loss < best_val_loss:\n",
863
+ " best_val_loss = val_loss\n",
864
+ " epochs_without_improvement = 0\n",
865
+ " # Save the model weights\n",
866
+ " torch.save(encoder.state_dict(), 'best_encoder.pt')\n",
867
+ " torch.save(decoder.state_dict(), 'best_decoder.pt')\n",
868
+ " else:\n",
869
+ " epochs_without_improvement += 1\n",
870
+ " if epochs_without_improvement >= patience:\n",
871
+ " print(\"Early stopping triggered. No improvement in validation loss.\")\n",
872
+ " break\n",
873
+ " \n",
874
+ " \n",
875
+ " # if testing mode is on print the test accuracy \n",
876
+ " if test:\n",
877
+ " # Load the best model weights\n",
878
+ " encoder.load_state_dict(torch.load('best_encoder.pt'))\n",
879
+ " decoder.load_state_dict(torch.load('best_decoder.pt'))\n",
880
+ " if attention == \"Yes\":\n",
881
+ " _, test_accuracy, pred, atten_weights = evaluate(testData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
882
+ " else:\n",
883
+ " _, test_accuracy, pred = evaluate(testData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
884
+ " print(f\"test accuracy {test_accuracy}\")\n",
885
+ "\n",
886
+ " if attention == \"Yes\":\n",
887
+ " return pred, atten_weights\n",
888
+ " else:\n",
889
+ " return pred\n",
890
+ " "
891
+ ]
892
+ },
893
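The decoding loop in train() applies teacher forcing: with probability teach_ratio the next decoder input is the ground-truth character, otherwise it is the model's own previous prediction. A minimal sketch of that choice in isolation, with `next_decoder_input` as a hypothetical helper name:

```python
import random
import torch

def next_decoder_input(ground_truth, prediction, teach_ratio=0.5):
    # teacher forcing: sometimes feed the truth, sometimes the model's own guess
    return ground_truth if random.random() < teach_ratio else prediction

gt = torch.tensor([5, 7])   # ground-truth character indices for this step
pr = torch.tensor([5, 9])   # the model's predictions for this step
print(next_decoder_input(gt, pr, teach_ratio=0.6))
```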
+ {
894
+ "cell_type": "markdown",
895
+ "metadata": {
896
+ "id": "nvyRJWUUbR2f"
897
+ },
898
+ "source": [
899
+ "# Translating predictions to words\n"
900
+ ]
901
+ },
902
+ {
903
+ "cell_type": "code",
904
+ "execution_count": null,
905
+ "metadata": {
906
+ "id": "Hd3zCTnSbSaL"
907
+ },
908
+ "outputs": [],
909
+ "source": [
910
+ "def translate_prediction(input_dict , input, output_dict, pred,target):\n",
911
+ " \n",
912
+ " '''pred in shape of seq_len-1 * dataset_size\n",
913
+ " target in shape datasize * seq_len-1\n",
914
+ " '''\n",
915
+ " pred = pred.T # shape datasize * seq len-1\n",
916
+ " pred = pred[1:, :-1] # ignore last index of each row\n",
917
+ " input = input[:, :-1] # ignore last index of each row\n",
918
+ " target = target[:, 1:-1] # ignore last index of each row\n",
919
+ " print(f\"pred shape {pred.shape}, input shape {input.shape}, target shape {target.shape}\")\n",
920
+ " predictions = [] \n",
921
+ " Input = [] \n",
922
+ " Target = []\n",
923
+ " for i in range(len(pred)):\n",
924
+ " \n",
925
+ " pred_word=\"\"\n",
926
+ " input_word=\"\"\n",
927
+ " target_word = \"\"\n",
928
+ "\n",
929
+ " for j in range(pred.shape[1]):\n",
930
+ "\n",
931
+ " # Ignore padding\n",
932
+ " if(target[i][j].item() != 0):\n",
933
+ " \n",
934
+ " pred_word += output_dict[pred[i][j].item()]\n",
935
+ " target_word += output_dict[target[i][j].item()]\n",
936
+ " \n",
937
+ " for j in range(input.shape[1]):\n",
938
+ " \n",
939
+ " if(input[i][j].item()!=0):\n",
940
+ " \n",
941
+ " input_word += input_dict[input[i][j].item()] \n",
942
+ "\n",
943
+ " # Append words in respective List\n",
944
+ " \n",
945
+ " predictions.append(pred_word)\n",
946
+ " Input.append(input_word) \n",
947
+ " Target.append(target_word) \n",
948
+ "\n",
949
+ " # Create a DataFrame\n",
950
+ " df = pd.DataFrame({\"input\": Input, \"predicted\": predictions,\"Actual\":Target})\n",
951
+ " return df\n",
952
+ "\n",
953
+ " "
954
+ ]
955
+ },
956
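translate_prediction walks each row of index tensors back into a string, skipping the padding index 0. The same idea in miniature, with a toy index2char map standing in for the real ipLang/opLang dictionaries:

```python
def indices_to_word(indices, index2char, pad_index=0):
    # map each non-padding index back to its character
    return ''.join(index2char[int(i)] for i in indices if int(i) != pad_index)

index2char = {0: '#', 1: 'n', 2: 'a', 3: 'm'}        # toy vocabulary
print(indices_to_word([1, 2, 3, 0, 0], index2char))  # 'nam'
```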
+ {
957
+ "cell_type": "markdown",
958
+ "metadata": {
959
+ "id": "8ETW0BG_Pa24"
960
+ },
961
+ "source": [
962
+ "#call train"
963
+ ]
964
+ },
965
+ {
966
+ "cell_type": "code",
967
+ "execution_count": null,
968
+ "metadata": {
969
+ "id": "pgGp7MoGzfPg"
970
+ },
971
+ "outputs": [],
972
+ "source": [
973
+ "# train(sweeps = False, test = True)"
974
+ ]
975
+ },
976
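train() also implements patience-based early stopping: training halts once the validation loss has failed to improve for patience consecutive epochs, and the best checkpoint is reloaded for testing. The stopping rule on its own, run over a toy loss trace:

```python
# patience-based early stopping over a toy validation-loss trace
best, patience, bad_epochs = float('inf'), 5, 0
for epoch, val_loss in enumerate([0.9, 0.8, 0.85, 0.83, 0.84, 0.86, 0.9, 0.88]):
    if val_loss < best:
        best, bad_epochs = val_loss, 0   # improvement: reset the counter
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"early stop at epoch {epoch}")
            break
```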
+ {
977
+ "cell_type": "markdown",
978
+ "metadata": {
979
+ "id": "MQPGy32rnD3V"
980
+ },
981
+ "source": [
982
+ "# Runnning sweeps for models without Attention\n",
983
+ "\n"
984
+ ]
985
+ },
986
+ {
987
+ "cell_type": "markdown",
988
+ "metadata": {
989
+ "id": "z_aYZvDD1OHU"
990
+ },
991
+ "source": [
992
+ "## Sweep Config"
993
+ ]
994
+ },
995
+ {
996
+ "cell_type": "code",
997
+ "execution_count": null,
998
+ "metadata": {
999
+ "id": "SVv8bI-D1Q_I"
1000
+ },
1001
+ "outputs": [],
1002
+ "source": [
1003
+ "sweep_config = {\n",
1004
+ " 'name': 'sweepDL', \n",
1005
+ " 'method': 'bayes',\n",
1006
+ " 'metric': {\n",
1007
+ " 'name': 'val_accuracy',\n",
1008
+ " 'goal': 'maximize'\n",
1009
+ " },\n",
1010
+ " 'parameters': {\n",
1011
+ " \n",
1012
+ " 'learn_rate': {\n",
1013
+ " 'values': [0.01, 0.001, 0.001]\n",
1014
+ " },\n",
1015
+ " 'embedding_size': {\n",
1016
+ " 'values': [32, 64, 128, 256, 512, 1024]\n",
1017
+ " },\n",
1018
+ " 'batch_size':{\n",
1019
+ " 'values':[16, 32, 64, 128, 256]\n",
1020
+ " },\n",
1021
+ " 'hidden_size':{\n",
1022
+ " 'values':[32, 64, 128, 256, 512, 1024]\n",
1023
+ " },\n",
1024
+ " 'teach_ratio':{\n",
1025
+ " 'values':[0.4, 0.5, 0.6]\n",
1026
+ " },\n",
1027
+ " 'dropout':{\n",
1028
+ " 'values':[0, 0.2, 0.4]\n",
1029
+ " },\n",
1030
+ " 'cell_type':{\n",
1031
+ " 'values':[\"RNN\", \"LSTM\", \"GRU\"]\n",
1032
+ " },\n",
1033
+ " 'bidirectional':{\n",
1034
+ " 'values' : [\"Yes\",\"No\"]\n",
1035
+ " },\n",
1036
+ " 'num_layers_decoder':{\n",
1037
+ " 'values': [1,2, 3, 4]\n",
1038
+ " },\n",
1039
+ " 'num_layers_encoder':{\n",
1040
+ " 'values': [1,2,3,4]\n",
1041
+ " },\n",
1042
+ " 'epochs':{\n",
1043
+ " 'values': [10, 15, 20, 25, 30]\n",
1044
+ " },\n",
1045
+ " 'attention':{\n",
1046
+ " 'values': [\"Yes\"]\n",
1047
+ " }\n",
1048
+ " \n",
1049
+ " }\n",
1050
+ "}\n",
1051
+ "config_defaults={\n",
1052
+ " 'learn_rate' : 0.001,\n",
1053
+ " 'embedding_size': 32,\n",
1054
+ " 'batch_size': 256,\n",
1055
+ " 'hidden_size' : 1024,\n",
1056
+ " 'num_layers_encoder': 3,\n",
1057
+ " 'num_layers_decoder': 3,\n",
1058
+ " 'bidirectional': 'No',\n",
1059
+ " 'cell_type': \"LSTM\",\n",
1060
+ " 'teach_ratio': 0.6,\n",
1061
+ " 'dropout': 0.4,\n",
1062
+ " 'epochs': 15,\n",
1063
+ " 'attention': \"No\"\n",
1064
+ "}"
1065
+ ]
1066
+ },
1067
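The next cell launches the Bayesian sweep exactly as configured above. When compute is limited, wandb.agent also accepts a count argument that caps how many configurations the agent tries; a hedged sketch (the cap of 20 is arbitrary):

```python
import wandb

sweep_id = wandb.sweep(sweep_config, project="CS6910_Assignment_3")
wandb.agent(sweep_id, function=train, count=20)  # stop after 20 runs
```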
+ {
1068
+ "cell_type": "code",
1069
+ "execution_count": null,
1070
+ "metadata": {
1071
+ "id": "4KxsOOpvr1oi"
1072
+ },
1073
+ "outputs": [],
1074
+ "source": [
1075
+ "sweep_id=wandb.sweep(sweep_config, project=\"CS6910_Assignment_3\")\n",
1076
+ "wandb.agent(sweep_id,function=train)"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "cell_type": "markdown",
1081
+ "metadata": {
1082
+ "id": "pKvBd5mKf0Hf"
1083
+ },
1084
+ "source": [
1085
+ "# Testing the Best Model(without Attention) on Test Data \n",
1086
+ "Set default hyperparameters to the best hyperparameters got from sweeps Hyperparamer tuning"
1087
+ ]
1088
+ },
1089
+ {
1090
+ "cell_type": "code",
1091
+ "execution_count": null,
1092
+ "metadata": {
1093
+ "id": "kMQvZjZl0q4U"
1094
+ },
1095
+ "outputs": [],
1096
+ "source": [
1097
+ "config_defaults={\n",
1098
+ " 'learn_rate' : 0.001,\n",
1099
+ " 'embedding_size': 32,\n",
1100
+ " 'batch_size': 256,\n",
1101
+ " 'hidden_size' : 1024,\n",
1102
+ " 'num_layers_encoder': 3,\n",
1103
+ " 'num_layers_decoder': 3,\n",
1104
+ " 'bidirectional': 'No',\n",
1105
+ " 'cell_type': \"LSTM\",\n",
1106
+ " 'teach_ratio': 0.6,\n",
1107
+ " 'dropout': 0.4,\n",
1108
+ " 'epochs': 15,\n",
1109
+ " 'attention': \"No\"\n",
1110
+ "}"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "execution_count": null,
1116
+ "metadata": {
1117
+ "colab": {
1118
+ "base_uri": "https://localhost:8080/"
1119
+ },
1120
+ "id": "ygtFpEvp8jFU",
1121
+ "outputId": "1a71d3be-f17f-498c-8844-3c115c411f0a"
1122
+ },
1123
+ "outputs": [
1124
+ {
1125
+ "name": "stdout",
1126
+ "output_type": "stream",
1127
+ "text": [
1128
+ "correct = 1490\n",
1129
+ "test accuracy 36.376953125\n"
1130
+ ]
1131
+ }
1132
+ ],
1133
+ "source": [
1134
+ "pred= train(sweeps = False, test = True)"
1135
+ ]
1136
+ },
1137
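The "correct = 1490" output above counts word-level exact matches: a prediction scores only when every character of the word agrees with the target. count_exact_matches is defined earlier in the notebook; this is a plausible sketch of its behaviour, not necessarily the exact implementation:

```python
import torch

def count_exact_matches(pred, target):
    # a row counts only if all of its characters match
    return (pred == target).all(dim=1).sum().item()

pred   = torch.tensor([[1, 2, 3], [4, 5, 6]])
target = torch.tensor([[1, 2, 3], [4, 5, 0]])
print(count_exact_matches(pred, target))  # 1
```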
+ {
1138
+ "cell_type": "markdown",
1139
+ "metadata": {
1140
+ "id": "hMf0OAuscOJx"
1141
+ },
1142
+ "source": [
1143
+ "# Saving the predictions by Vanilla model in csv file"
1144
+ ]
1145
+ },
1146
+ {
1147
+ "cell_type": "code",
1148
+ "execution_count": null,
1149
+ "metadata": {
1150
+ "colab": {
1151
+ "base_uri": "https://localhost:8080/"
1152
+ },
1153
+ "id": "1cgUOUdsfzUB",
1154
+ "outputId": "8784a3aa-315e-476f-cced-c38ebb8434b3"
1155
+ },
1156
+ "outputs": [
1157
+ {
1158
+ "name": "stdout",
1159
+ "output_type": "stream",
1160
+ "text": [
1161
+ "pred shape torch.Size([4096, 20]), input shape torch.Size([4096, 26]), target shape torch.Size([4096, 20])\n"
1162
+ ]
1163
+ }
1164
+ ],
1165
+ "source": [
1166
+ "# save the predictions\n",
1167
+ "dataframe = translate_prediction(ipLang.index2char, testData[:][0], opLang.index2char, pred, testData[:][1])\n",
1168
+ "dataframe.to_csv(\"predictions.csv\")"
1169
+ ]
1170
+ },
1171
+ {
1172
+ "cell_type": "code",
1173
+ "execution_count": null,
1174
+ "metadata": {
1175
+ "id": "ZZW-IEWZ5syU"
1176
+ },
1177
+ "outputs": [],
1178
+ "source": [
1179
+ "import pandas as pd\n",
1180
+ "data = pd.read_csv(\"predictions.csv\")"
1181
+ ]
1182
+ },
1183
+ {
1184
+ "cell_type": "code",
1185
+ "execution_count": null,
1186
+ "metadata": {
1187
+ "colab": {
1188
+ "base_uri": "https://localhost:8080/",
1189
+ "height": 424
1190
+ },
1191
+ "id": "2sOkc_0vmDlB",
1192
+ "outputId": "750d06b5-fee2-4eb8-d7e6-a7043cd0c15a"
1193
+ },
1194
+ "outputs": [],
1195
+ "source": [
1196
+ "data"
1197
+ ]
1198
+ },
1199
+ {
1200
+ "cell_type": "code",
1201
+ "execution_count": null,
1202
+ "metadata": {
1203
+ "colab": {
1204
+ "base_uri": "https://localhost:8080/",
1205
+ "height": 142
1206
+ },
1207
+ "id": "AkG1vCpZ_vjG",
1208
+ "outputId": "d64b794c-d173-4871-80fc-93b8211ebedc"
1209
+ },
1210
+ "outputs": [],
1211
+ "source": [
1212
+ "# We also want to plot the prdiction table to wandb\n",
1213
+ "wandb.init(project=\"CS6910_Assignment_3\")"
1214
+ ]
1215
+ },
1216
+ {
1217
+ "cell_type": "code",
1218
+ "execution_count": null,
1219
+ "metadata": {
1220
+ "id": "MmKDX6V5_kGu"
1221
+ },
1222
+ "outputs": [],
1223
+ "source": [
1224
+ "table = wandb.Table(dataframe=data)\n",
1225
+ "wandb.log({\"data\": table})"
1226
+ ]
1227
+ },
1228
+ {
1229
+ "cell_type": "markdown",
1230
+ "metadata": {
1231
+ "id": "FYMa5jTQRUaB"
1232
+ },
1233
+ "source": [
1234
+ "## Plotting the confusion matrix in wandB"
1235
+ ]
1236
+ },
1237
+ {
1238
+ "cell_type": "code",
1239
+ "execution_count": null,
1240
+ "metadata": {
1241
+ "id": "YBaJZCIBRAGZ"
1242
+ },
1243
+ "outputs": [],
1244
+ "source": [
1245
+ "import numpy as np\n",
1246
+ "CM = np.zeros((opLang.n_chars, ipLang.n_chars))\n",
1247
+ "\n",
1248
+ "for i in range(len(testData[1])):\n",
1249
+ " for j in range(testData[1].shape[1]):\n",
1250
+ " pred = int(pred[i][j])\n",
1251
+ " targ = int(testData[1][i][j])\n",
1252
+ " CM[pred][targ] += 1\n",
1253
+ "\n",
1254
+ "classes =[]\n",
1255
+ "\n",
1256
+ "for i in range(len(CM)):\n",
1257
+ " classes.append(opLang.index2char[i])\n",
1258
+ "\n",
1259
+ "percentages = 100 * (CM / np.sum(CM))\n",
1260
+ "\n",
1261
+ "# Define the text for each cell\n",
1262
+ "cell_text = []\n",
1263
+ "for i in range(len(classes)):\n",
1264
+ " row_text = []\n",
1265
+ " for j in range(len(classes)):\n",
1266
+ "\n",
1267
+ " txt = \"Total \"+f'{CM[i, j]}Per. ({percentages[i, j]:.3f})'\n",
1268
+ " if(i==j):\n",
1269
+ " txt =\"Correcty Predicted \" +classes[i]+\"\"+txt\n",
1270
+ " if(i!=j):\n",
1271
+ " txt =\"Predicted \" +classes[j]+\" For \"+classes[i]+\"\"+txt\n",
1272
+ " row_text.append(txt)\n",
1273
+ " cell_text.append(row_text)\n",
1274
+ "\n",
1275
+ "import plotly.graph_objs as go\n",
1276
+ "\n",
1277
+ "# Define the trace\n",
1278
+ "trace = go.Heatmap(z=percentages,\n",
1279
+ " x=classes,\n",
1280
+ " y=classes,\n",
1281
+ " colorscale='Blues',\n",
1282
+ " colorbar=dict(title='Percentage'),\n",
1283
+ " hovertemplate='%{text}%',\n",
1284
+ " text=cell_text,\n",
1285
+ " )\n",
1286
+ "\n",
1287
+ "# Define the layout\n",
1288
+ "layout = go.Layout(title='Confusion Matrix',\n",
1289
+ " xaxis=dict(title='Predicted Character'),\n",
1290
+ " yaxis=dict(title='True Character'),\n",
1291
+ " )\n",
1292
+ "\n",
1293
+ "# Plot the figure\n",
1294
+ "fig = go.Figure(data=[trace], layout=layout)\n",
1295
+ "wandb.log({'confusion_matrix': (fig)})"
1296
+ ]
1297
+ },
1298
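The loop above accumulates a character-level confusion matrix indexed as CM[predicted][true]. The same bookkeeping on toy data, independent of the notebook's tensors:

```python
import numpy as np

true = np.array([1, 2, 2, 3])   # true character indices
hyp  = np.array([1, 2, 3, 3])   # predicted character indices
cm = np.zeros((4, 4), dtype=int)
for t, p in zip(true, hyp):
    cm[p][t] += 1               # rows: predicted, columns: true
print(cm)
```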
+ {
1299
+ "cell_type": "markdown",
1300
+ "metadata": {
1301
+ "id": "zfuv5FoA1wt2"
1302
+ },
1303
+ "source": [
1304
+ "# Runnning sweeps for models with Attention\n"
1305
+ ]
1306
+ },
1307
+ {
1308
+ "cell_type": "markdown",
1309
+ "metadata": {
1310
+ "id": "tsHS0PkNGHdV"
1311
+ },
1312
+ "source": [
1313
+ "## Sweep Config"
1314
+ ]
1315
+ },
1316
+ {
1317
+ "cell_type": "code",
1318
+ "execution_count": null,
1319
+ "metadata": {
1320
+ "id": "HwCn-Ci5xkTb"
1321
+ },
1322
+ "outputs": [],
1323
+ "source": [
1324
+ "sweep_config = {\n",
1325
+ " 'name': 'sweepDL', \n",
1326
+ " 'method': 'bayes',\n",
1327
+ " 'metric': {\n",
1328
+ " 'name': 'val_accuracy',\n",
1329
+ " 'goal': 'maximize'\n",
1330
+ " },\n",
1331
+ " 'parameters': {\n",
1332
+ " \n",
1333
+ " 'learn_rate': {\n",
1334
+ " 'values': [0.01, 0.001, 0.001]\n",
1335
+ " },\n",
1336
+ " 'embedding_size': {\n",
1337
+ " 'values': [32, 64, 128, 256, 512, 1024]\n",
1338
+ " },\n",
1339
+ " 'batch_size':{\n",
1340
+ " 'values':[16, 32, 64, 128, 256]\n",
1341
+ " },\n",
1342
+ " 'hidden_size':{\n",
1343
+ " 'values':[32, 64, 128, 256, 512, 1024]\n",
1344
+ " },\n",
1345
+ " 'teach_ratio':{\n",
1346
+ " 'values':[0.4, 0.5, 0.6]\n",
1347
+ " },\n",
1348
+ " 'dropout':{\n",
1349
+ " 'values':[0, 0.2, 0.4]\n",
1350
+ " },\n",
1351
+ " 'cell_type':{\n",
1352
+ " 'values':[\"RNN\", \"LSTM\", \"GRU\"]\n",
1353
+ " },\n",
1354
+ " 'bidirectional':{\n",
1355
+ " 'values' : [\"Yes\",\"No\"]\n",
1356
+ " },\n",
1357
+ " 'num_layers_decoder':{\n",
1358
+ " 'values': [1,2, 3, 4]\n",
1359
+ " },\n",
1360
+ " 'num_layers_encoder':{\n",
1361
+ " 'values': [1,2,3,4]\n",
1362
+ " },\n",
1363
+ " 'epochs':{\n",
1364
+ " 'values': [10, 15, 20, 25, 30]\n",
1365
+ " },\n",
1366
+ " 'attention':{\n",
1367
+ " 'values': [\"Yes\"]\n",
1368
+ " }\n",
1369
+ " \n",
1370
+ " }\n",
1371
+ "}\n",
1372
+ "config_defaults={\n",
1373
+ " 'learn_rate' : 0.001,\n",
1374
+ " 'embedding_size': 32,\n",
1375
+ " 'batch_size': 64,\n",
1376
+ " 'hidden_size' : 1024,\n",
1377
+ " 'num_layers_encoder': 1,\n",
1378
+ " 'num_layers_decoder': 1,\n",
1379
+ " 'bidirectional': 'Yes',\n",
1380
+ " 'cell_type': \"LSTM\",\n",
1381
+ " 'teach_ratio': 0.5,\n",
1382
+ " 'dropout': 0.4,\n",
1383
+ " 'epochs': 20,\n",
1384
+ " 'attention': \"Yes\"\n",
1385
+ "}"
1386
+ ]
1387
+ },
1388
+ {
1389
+ "cell_type": "code",
1390
+ "execution_count": null,
1391
+ "metadata": {
1392
+ "id": "3ADMwinqaQVF"
1393
+ },
1394
+ "outputs": [],
1395
+ "source": [
1396
+ "sweep_id=wandb.sweep(sweep_config, project=\"CS6910_Assignment_3\")\n",
1397
+ "wandb.agent(sweep_id,function=train)\n",
1398
+ "# wandb.agent(sweep_id= \"xiyggu44\",function=train, project=\"CS6910_Assignment_3\")"
1399
+ ]
1400
+ },
1401
+ {
1402
+ "cell_type": "markdown",
1403
+ "metadata": {
1404
+ "id": "W7CYNChRGuGK"
1405
+ },
1406
+ "source": [
1407
+ "# Testing the Best Model(with Attention) on Test Data \n",
1408
+ "Set default hyperparameters to the best hyperparameters got from sweeps Hyperparamer tuning"
1409
+ ]
1410
+ },
1411
+ {
1412
+ "cell_type": "code",
1413
+ "execution_count": null,
1414
+ "metadata": {
1415
+ "id": "C9MUrsXu_Rr4"
1416
+ },
1417
+ "outputs": [],
1418
+ "source": [
1419
+ "config_defaults={\n",
1420
+ " 'learn_rate' : 0.001,\n",
1421
+ " 'embedding_size': 32,\n",
1422
+ " 'batch_size': 64,\n",
1423
+ " 'hidden_size' : 1024,\n",
1424
+ " 'num_layers_encoder': 1,\n",
1425
+ " 'num_layers_decoder': 1,\n",
1426
+ " 'bidirectional': 'Yes',\n",
1427
+ " 'cell_type': \"LSTM\",\n",
1428
+ " 'teach_ratio': 0.5,\n",
1429
+ " 'dropout': 0.4,\n",
1430
+ " 'epochs': 20,\n",
1431
+ " 'attention': \"Yes\"\n",
1432
+ "}"
1433
+ ]
1434
+ },
1435
+ {
1436
+ "cell_type": "code",
1437
+ "execution_count": null,
1438
+ "metadata": {
1439
+ "id": "u7XAB4Q5Hpxj"
1440
+ },
1441
+ "outputs": [],
1442
+ "source": [
1443
+ "pred, atten_weights = train(sweeps = False, test = True)"
1444
+ ]
1445
+ },
1446
+ {
1447
+ "cell_type": "markdown",
1448
+ "metadata": {
1449
+ "id": "fld21YRZdRdG"
1450
+ },
1451
+ "source": [
1452
+ "# Saving the predictions by Vanilla model in csv file"
1453
+ ]
1454
+ },
1455
+ {
1456
+ "cell_type": "code",
1457
+ "execution_count": null,
1458
+ "metadata": {
1459
+ "colab": {
1460
+ "base_uri": "https://localhost:8080/"
1461
+ },
1462
+ "id": "BpDQ1mrydYWg",
1463
+ "outputId": "8784a3aa-315e-476f-cced-c38ebb8434b3"
1464
+ },
1465
+ "outputs": [
1466
+ {
1467
+ "name": "stdout",
1468
+ "output_type": "stream",
1469
+ "text": [
1470
+ "pred shape torch.Size([4096, 20]), input shape torch.Size([4096, 26]), target shape torch.Size([4096, 20])\n"
1471
+ ]
1472
+ }
1473
+ ],
1474
+ "source": [
1475
+ "# save the predictions\n",
1476
+ "dataframe = translate_prediction(ipLang.index2char, testData[:][0], opLang.index2char, pred, testData[:][1])\n",
1477
+ "dataframe.to_csv(\"predictions.csv\")"
1478
+ ]
1479
+ },
1480
+ {
1481
+ "cell_type": "code",
1482
+ "execution_count": null,
1483
+ "metadata": {
1484
+ "id": "PKMYPZdtdbDh"
1485
+ },
1486
+ "outputs": [],
1487
+ "source": [
1488
+ "import pandas as pd\n",
1489
+ "data = pd.read_csv(\"predictions.csv\")"
1490
+ ]
1491
+ },
1492
+ {
1493
+ "cell_type": "code",
1494
+ "execution_count": null,
1495
+ "metadata": {
1496
+ "colab": {
1497
+ "base_uri": "https://localhost:8080/",
1498
+ "height": 142
1499
+ },
1500
+ "id": "8gCL1rXCdgYp",
1501
+ "outputId": "d64b794c-d173-4871-80fc-93b8211ebedc"
1502
+ },
1503
+ "outputs": [],
1504
+ "source": [
1505
+ "# We also want to plot the prdiction table to wandb\n",
1506
+ "wandb.init(project=\"CS6910_Assignment_3\")"
1507
+ ]
1508
+ },
1509
+ {
1510
+ "cell_type": "code",
1511
+ "execution_count": null,
1512
+ "metadata": {
1513
+ "id": "N1r2ownhdjbz"
1514
+ },
1515
+ "outputs": [],
1516
+ "source": [
1517
+ "table = wandb.Table(dataframe=data)\n",
1518
+ "wandb.log({\"data\": table})"
1519
+ ]
1520
+ },
1521
+ {
1522
+ "cell_type": "markdown",
1523
+ "metadata": {
1524
+ "id": "LDP4KvWdFnIL"
1525
+ },
1526
+ "source": [
1527
+ "# Plotting the Attention HeatMaps"
1528
+ ]
1529
+ },
1530
+ {
1531
+ "cell_type": "code",
1532
+ "execution_count": null,
1533
+ "metadata": {
1534
+ "colab": {
1535
+ "base_uri": "https://localhost:8080/",
1536
+ "height": 1000,
1537
+ "referenced_widgets": [
1538
+ "1b0c5a6e21a349cba57322f850ad9f48",
1539
+ "3aa935a6db14483d8aaada58a84a3e47",
1540
+ "eabcea7a8bbf42f6aaa3995c0dece721",
1541
+ "b3b7711edb5542e08c53c4f37da10203",
1542
+ "39a8a3a9b6f1495ea17fd1b3d86b67c0",
1543
+ "18a8e2e817b947f9aad87b1ccaf96ea6",
1544
+ "da62d6e5ad0a462b98e1591d39038e1e",
1545
+ "9b5bb4f7f4a846c28ab967b64107726e"
1546
+ ]
1547
+ },
1548
+ "id": "4WfJEdcgFmiI",
1549
+ "outputId": "ff266529-4345-4cdc-9860-11914b099052"
1550
+ },
1551
+ "outputs": [],
1552
+ "source": [
1553
+ "import matplotlib.pyplot as plt\n",
1554
+ "import numpy as np\n",
1555
+ "from matplotlib.font_manager import FontProperties\n",
1556
+ "tel_font = FontProperties(fname = 'TiroDevanagariHindi-Regular.ttf')\n",
1557
+ "# Assuming you have attention_weights of shape (batch_size, output_sequence_length, batch_size, input_sequence_length)\n",
1558
+ "# and prediction_matrix of shape (batch_size, output_sequence_length)\n",
1559
+ "# and input_matrix of shape (batch_size, input_sequence_length)\n",
1560
+ "\n",
1561
+ "# Define the grid dimensions\n",
1562
+ "rows = int(np.ceil(np.sqrt(12)))\n",
1563
+ "cols = int(np.ceil(12 / rows))\n",
1564
+ "\n",
1565
+ "# Create a figure and subplots\n",
1566
+ "fig, axes = plt.subplots(rows, cols, figsize=(9, 9))\n",
1567
+ "\n",
1568
+ "for i, ax in enumerate(axes.flatten()):\n",
1569
+ " if i < 12:\n",
1570
+ " prediction = [opLang.index2char[j.item()] for j in pred[i+1]]\n",
1571
+ " \n",
1572
+ " pred_word=\"\"\n",
1573
+ " input_word=\"\"\n",
1574
+ "\n",
1575
+ " for j in range(len(prediction)):\n",
1576
+ " # Ignore padding\n",
1577
+ " if(prediction[j] != '#'):\n",
1578
+ " pred_word += prediction[j]\n",
1579
+ " else : \n",
1580
+ " break\n",
1581
+ " input_seq = [ipLang.index2char[j.item()] for j in testData[i][0]]\n",
1582
+ " \n",
1583
+ " for j in range(len(input_seq)):\n",
1584
+ " if(input_seq[j] != '#'):\n",
1585
+ " input_word += input_seq[j]\n",
1586
+ " else : \n",
1587
+ " break\n",
1588
+ " attn_weights = atten_weights[i, :len(pred_word), :len(input_word)].detach().cpu().numpy()\n",
1589
+ " ax.imshow(attn_weights.T, cmap='hot', interpolation='nearest')\n",
1590
+ " ax.xaxis.set_label_position('top')\n",
1591
+ " ax.set_title(f'Example {i+1}')\n",
1592
+ " ax.set_xlabel('Output predicted')\n",
1593
+ " ax.set_ylabel('Input word')\n",
1594
+ " ax.set_xticks(np.arange(len(pred_word)))\n",
1595
+ " ax.set_xticklabels(pred_word, rotation = 90, fontproperties = tel_font,fontdict={'fontsize':8})\n",
1596
+ " ax.xaxis.tick_top()\n",
1597
+ "\n",
1598
+ " ax.set_yticks(np.arange(len(input_word)))\n",
1599
+ " ax.set_yticklabels(input_word, rotation=90)\n",
1600
+ " \n",
1601
+ " \n",
1602
+ "\n",
1603
+ "# Adjust the spacing between subplots\n",
1604
+ "plt.tight_layout()\n",
1605
+ "\n",
1606
+ "# Show the plot\n",
1607
+ "plt.show()\n",
1608
+ "wandb.init(project='CS6910_Assignment_3')\n",
1609
+ "\n",
1610
+ "# Convert the matplotlib figure to an image\n",
1611
+ "fig.canvas.draw()\n",
1612
+ "image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')\n",
1613
+ "image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
1614
+ "\n",
1615
+ "# Log the image in wandb\n",
1616
+ "wandb.log({\"attention_heatmaps\": [wandb.Image(image)]})"
1617
+ ]
1618
+ },
1619
+ {
1620
+ "cell_type": "code",
1621
+ "execution_count": null,
1622
+ "metadata": {
1623
+ "id": "FnHR_oql6-S4"
1624
+ },
1625
+ "outputs": [],
1626
+ "source": []
1627
+ }
1628
+ ],
1629
+ "metadata": {
1630
+ "accelerator": "GPU",
1631
+ "colab": {
1632
+ "collapsed_sections": [
1633
+ "hRdpoWePeYHn",
1634
+ "44xIRolL_T_d",
1635
+ "Q1TioafYgICa",
1636
+ "svxssm9Havhb",
1637
+ "J56aq1J6a07q",
1638
+ "5JcQdylzI_Fc",
1639
+ "658W9RARGEUf",
1640
+ "q7fAgs5uQni_",
1641
+ "n4rGh7vuQqaa",
1642
+ "0SsnRWlgQmCI",
1643
+ "nvyRJWUUbR2f",
1644
+ "8ETW0BG_Pa24",
1645
+ "MQPGy32rnD3V",
1646
+ "z_aYZvDD1OHU",
1647
+ "pKvBd5mKf0Hf",
1648
+ "FYMa5jTQRUaB",
1649
+ "zfuv5FoA1wt2",
1650
+ "W7CYNChRGuGK"
1651
+ ],
1652
+ "gpuType": "T4",
1653
+ "include_colab_link": true,
1654
+ "provenance": []
1655
+ },
1656
+ "gpuClass": "standard",
1657
+ "kernelspec": {
1658
+ "display_name": "Python 3",
1659
+ "name": "python3"
1660
+ },
1661
+ "language_info": {
1662
+ "name": "python"
1663
+ },
1664
+ "widgets": {
1665
+ "application/vnd.jupyter.widget-state+json": {
1666
+ "18a8e2e817b947f9aad87b1ccaf96ea6": {
1667
+ "model_module": "@jupyter-widgets/controls",
1668
+ "model_module_version": "1.5.0",
1669
+ "model_name": "DescriptionStyleModel",
1670
+ "state": {
1671
+ "_model_module": "@jupyter-widgets/controls",
1672
+ "_model_module_version": "1.5.0",
1673
+ "_model_name": "DescriptionStyleModel",
1674
+ "_view_count": null,
1675
+ "_view_module": "@jupyter-widgets/base",
1676
+ "_view_module_version": "1.2.0",
1677
+ "_view_name": "StyleView",
1678
+ "description_width": ""
1679
+ }
1680
+ },
1681
+ "1b0c5a6e21a349cba57322f850ad9f48": {
1682
+ "model_module": "@jupyter-widgets/controls",
1683
+ "model_module_version": "1.5.0",
1684
+ "model_name": "VBoxModel",
1685
+ "state": {
1686
+ "_dom_classes": [],
1687
+ "_model_module": "@jupyter-widgets/controls",
1688
+ "_model_module_version": "1.5.0",
1689
+ "_model_name": "VBoxModel",
1690
+ "_view_count": null,
1691
+ "_view_module": "@jupyter-widgets/controls",
1692
+ "_view_module_version": "1.5.0",
1693
+ "_view_name": "VBoxView",
1694
+ "box_style": "",
1695
+ "children": [
1696
+ "IPY_MODEL_3aa935a6db14483d8aaada58a84a3e47",
1697
+ "IPY_MODEL_eabcea7a8bbf42f6aaa3995c0dece721"
1698
+ ],
1699
+ "layout": "IPY_MODEL_b3b7711edb5542e08c53c4f37da10203"
1700
+ }
1701
+ },
1702
+ "39a8a3a9b6f1495ea17fd1b3d86b67c0": {
1703
+ "model_module": "@jupyter-widgets/base",
1704
+ "model_module_version": "1.2.0",
1705
+ "model_name": "LayoutModel",
1706
+ "state": {
1707
+ "_model_module": "@jupyter-widgets/base",
1708
+ "_model_module_version": "1.2.0",
1709
+ "_model_name": "LayoutModel",
1710
+ "_view_count": null,
1711
+ "_view_module": "@jupyter-widgets/base",
1712
+ "_view_module_version": "1.2.0",
1713
+ "_view_name": "LayoutView",
1714
+ "align_content": null,
1715
+ "align_items": null,
1716
+ "align_self": null,
1717
+ "border": null,
1718
+ "bottom": null,
1719
+ "display": null,
1720
+ "flex": null,
1721
+ "flex_flow": null,
1722
+ "grid_area": null,
1723
+ "grid_auto_columns": null,
1724
+ "grid_auto_flow": null,
1725
+ "grid_auto_rows": null,
1726
+ "grid_column": null,
1727
+ "grid_gap": null,
1728
+ "grid_row": null,
1729
+ "grid_template_areas": null,
1730
+ "grid_template_columns": null,
1731
+ "grid_template_rows": null,
1732
+ "height": null,
1733
+ "justify_content": null,
1734
+ "justify_items": null,
1735
+ "left": null,
1736
+ "margin": null,
1737
+ "max_height": null,
1738
+ "max_width": null,
1739
+ "min_height": null,
1740
+ "min_width": null,
1741
+ "object_fit": null,
1742
+ "object_position": null,
1743
+ "order": null,
1744
+ "overflow": null,
1745
+ "overflow_x": null,
1746
+ "overflow_y": null,
1747
+ "padding": null,
1748
+ "right": null,
1749
+ "top": null,
1750
+ "visibility": null,
1751
+ "width": null
1752
+ }
1753
+ },
1754
+ "3aa935a6db14483d8aaada58a84a3e47": {
1755
+ "model_module": "@jupyter-widgets/controls",
1756
+ "model_module_version": "1.5.0",
1757
+ "model_name": "LabelModel",
1758
+ "state": {
1759
+ "_dom_classes": [],
1760
+ "_model_module": "@jupyter-widgets/controls",
1761
+ "_model_module_version": "1.5.0",
1762
+ "_model_name": "LabelModel",
1763
+ "_view_count": null,
1764
+ "_view_module": "@jupyter-widgets/controls",
1765
+ "_view_module_version": "1.5.0",
1766
+ "_view_name": "LabelView",
1767
+ "description": "",
1768
+ "description_tooltip": null,
1769
+ "layout": "IPY_MODEL_39a8a3a9b6f1495ea17fd1b3d86b67c0",
1770
+ "placeholder": "​",
1771
+ "style": "IPY_MODEL_18a8e2e817b947f9aad87b1ccaf96ea6",
1772
+ "value": "0.071 MB of 0.071 MB uploaded (0.000 MB deduped)\r"
1773
+ }
1774
+ },
1775
+ "9b5bb4f7f4a846c28ab967b64107726e": {
1776
+ "model_module": "@jupyter-widgets/controls",
1777
+ "model_module_version": "1.5.0",
1778
+ "model_name": "ProgressStyleModel",
1779
+ "state": {
1780
+ "_model_module": "@jupyter-widgets/controls",
1781
+ "_model_module_version": "1.5.0",
1782
+ "_model_name": "ProgressStyleModel",
1783
+ "_view_count": null,
1784
+ "_view_module": "@jupyter-widgets/base",
1785
+ "_view_module_version": "1.2.0",
1786
+ "_view_name": "StyleView",
1787
+ "bar_color": null,
1788
+ "description_width": ""
1789
+ }
1790
+ },
1791
+ "b3b7711edb5542e08c53c4f37da10203": {
1792
+ "model_module": "@jupyter-widgets/base",
1793
+ "model_module_version": "1.2.0",
1794
+ "model_name": "LayoutModel",
1795
+ "state": {
1796
+ "_model_module": "@jupyter-widgets/base",
1797
+ "_model_module_version": "1.2.0",
1798
+ "_model_name": "LayoutModel",
1799
+ "_view_count": null,
1800
+ "_view_module": "@jupyter-widgets/base",
1801
+ "_view_module_version": "1.2.0",
1802
+ "_view_name": "LayoutView",
1803
+ "align_content": null,
1804
+ "align_items": null,
1805
+ "align_self": null,
1806
+ "border": null,
1807
+ "bottom": null,
1808
+ "display": null,
1809
+ "flex": null,
1810
+ "flex_flow": null,
1811
+ "grid_area": null,
1812
+ "grid_auto_columns": null,
1813
+ "grid_auto_flow": null,
1814
+ "grid_auto_rows": null,
1815
+ "grid_column": null,
1816
+ "grid_gap": null,
1817
+ "grid_row": null,
1818
+ "grid_template_areas": null,
1819
+ "grid_template_columns": null,
1820
+ "grid_template_rows": null,
1821
+ "height": null,
1822
+ "justify_content": null,
1823
+ "justify_items": null,
1824
+ "left": null,
1825
+ "margin": null,
1826
+ "max_height": null,
1827
+ "max_width": null,
1828
+ "min_height": null,
1829
+ "min_width": null,
1830
+ "object_fit": null,
1831
+ "object_position": null,
1832
+ "order": null,
1833
+ "overflow": null,
1834
+ "overflow_x": null,
1835
+ "overflow_y": null,
1836
+ "padding": null,
1837
+ "right": null,
1838
+ "top": null,
1839
+ "visibility": null,
1840
+ "width": null
1841
+ }
1842
+ },
1843
+ "da62d6e5ad0a462b98e1591d39038e1e": {
1844
+ "model_module": "@jupyter-widgets/base",
1845
+ "model_module_version": "1.2.0",
1846
+ "model_name": "LayoutModel",
1847
+ "state": {
1848
+ "_model_module": "@jupyter-widgets/base",
1849
+ "_model_module_version": "1.2.0",
1850
+ "_model_name": "LayoutModel",
1851
+ "_view_count": null,
1852
+ "_view_module": "@jupyter-widgets/base",
1853
+ "_view_module_version": "1.2.0",
1854
+ "_view_name": "LayoutView",
1855
+ "align_content": null,
1856
+ "align_items": null,
1857
+ "align_self": null,
1858
+ "border": null,
1859
+ "bottom": null,
1860
+ "display": null,
1861
+ "flex": null,
1862
+ "flex_flow": null,
1863
+ "grid_area": null,
1864
+ "grid_auto_columns": null,
1865
+ "grid_auto_flow": null,
1866
+ "grid_auto_rows": null,
1867
+ "grid_column": null,
1868
+ "grid_gap": null,
1869
+ "grid_row": null,
1870
+ "grid_template_areas": null,
1871
+ "grid_template_columns": null,
1872
+ "grid_template_rows": null,
1873
+ "height": null,
1874
+ "justify_content": null,
1875
+ "justify_items": null,
1876
+ "left": null,
1877
+ "margin": null,
1878
+ "max_height": null,
1879
+ "max_width": null,
1880
+ "min_height": null,
1881
+ "min_width": null,
1882
+ "object_fit": null,
1883
+ "object_position": null,
1884
+ "order": null,
1885
+ "overflow": null,
1886
+ "overflow_x": null,
1887
+ "overflow_y": null,
1888
+ "padding": null,
1889
+ "right": null,
1890
+ "top": null,
1891
+ "visibility": null,
1892
+ "width": null
1893
+ }
1894
+ },
1895
+ "eabcea7a8bbf42f6aaa3995c0dece721": {
1896
+ "model_module": "@jupyter-widgets/controls",
1897
+ "model_module_version": "1.5.0",
1898
+ "model_name": "FloatProgressModel",
1899
+ "state": {
1900
+ "_dom_classes": [],
1901
+ "_model_module": "@jupyter-widgets/controls",
1902
+ "_model_module_version": "1.5.0",
1903
+ "_model_name": "FloatProgressModel",
1904
+ "_view_count": null,
1905
+ "_view_module": "@jupyter-widgets/controls",
1906
+ "_view_module_version": "1.5.0",
1907
+ "_view_name": "ProgressView",
1908
+ "bar_style": "",
1909
+ "description": "",
1910
+ "description_tooltip": null,
1911
+ "layout": "IPY_MODEL_da62d6e5ad0a462b98e1591d39038e1e",
1912
+ "max": 1,
1913
+ "min": 0,
1914
+ "orientation": "horizontal",
1915
+ "style": "IPY_MODEL_9b5bb4f7f4a846c28ab967b64107726e",
1916
+ "value": 1
1917
+ }
1918
+ }
1919
+ }
1920
+ }
1921
+ },
1922
+ "nbformat": 4,
1923
+ "nbformat_minor": 0
1924
+ }
notebooks/transformers.ipynb ADDED
@@ -0,0 +1,1929 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/pankajrawat9075/Language-Transliteration-Model/blob/main/transformers_encoder_decoder.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "hRdpoWePeYHn"
17
+ },
18
+ "source": [
19
+ "## Importing Libraries and models"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {
26
+ "execution": {
27
+ "iopub.execute_input": "2024-04-06T12:27:53.981869Z",
28
+ "iopub.status.busy": "2024-04-06T12:27:53.981590Z",
29
+ "iopub.status.idle": "2024-04-06T12:28:06.958537Z",
30
+ "shell.execute_reply": "2024-04-06T12:28:06.957350Z",
31
+ "shell.execute_reply.started": "2024-04-06T12:27:53.981844Z"
32
+ },
33
+ "id": "0LBvFtYGCNgJ",
34
+ "trusted": true
35
+ },
36
+ "outputs": [],
37
+ "source": [
38
+ "%%capture\n",
39
+ "!pip install wandb"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {
46
+ "execution": {
47
+ "iopub.execute_input": "2024-04-06T12:28:06.960754Z",
48
+ "iopub.status.busy": "2024-04-06T12:28:06.960461Z",
49
+ "iopub.status.idle": "2024-04-06T12:28:12.559713Z",
50
+ "shell.execute_reply": "2024-04-06T12:28:12.558903Z",
51
+ "shell.execute_reply.started": "2024-04-06T12:28:06.960728Z"
52
+ },
53
+ "id": "z4ZVrIumZcDt",
54
+ "trusted": true
55
+ },
56
+ "outputs": [],
57
+ "source": [
58
+ "from __future__ import unicode_literals, print_function, division\n",
59
+ "from io import open\n",
60
+ "import unicodedata\n",
61
+ "import string\n",
62
+ "import re\n",
63
+ "import wandb\n",
64
+ "import random\n",
65
+ "import pandas as pd\n",
66
+ "import torch\n",
67
+ "import time\n",
68
+ "import numpy as np\n",
69
+ "import torch.nn as nn\n",
70
+ "from torch import optim\n",
71
+ "import matplotlib.pyplot as plt\n",
72
+ "import torch.nn.functional as F\n",
73
+ "from torch.utils.data import TensorDataset, DataLoader\n",
74
+ "\n",
75
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
76
+ "torch.cuda.empty_cache()"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {
83
+ "colab": {
84
+ "base_uri": "https://localhost:8080/"
85
+ },
86
+ "execution": {
87
+ "iopub.execute_input": "2024-04-06T12:28:12.561336Z",
88
+ "iopub.status.busy": "2024-04-06T12:28:12.560805Z",
89
+ "iopub.status.idle": "2024-04-06T12:28:12.571498Z",
90
+ "shell.execute_reply": "2024-04-06T12:28:12.570579Z",
91
+ "shell.execute_reply.started": "2024-04-06T12:28:12.561311Z"
92
+ },
93
+ "id": "qwL09v65CIse",
94
+ "outputId": "5ea72523-6a50-474c-b617-b77e16d72ef3",
95
+ "trusted": true
96
+ },
97
+ "outputs": [
98
+ {
99
+ "name": "stdout",
100
+ "output_type": "stream",
101
+ "text": [
102
+ "cuda\n"
103
+ ]
104
+ }
105
+ ],
106
+ "source": [
107
+ "print(device)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "metadata": {
113
+ "id": "44xIRolL_T_d"
114
+ },
115
+ "source": [
116
+ "## Load Dataset"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "metadata": {
123
+ "execution": {
124
+ "iopub.execute_input": "2024-04-06T12:28:12.573774Z",
125
+ "iopub.status.busy": "2024-04-06T12:28:12.573504Z",
126
+ "iopub.status.idle": "2024-04-06T12:28:12.583678Z",
127
+ "shell.execute_reply": "2024-04-06T12:28:12.582875Z",
128
+ "shell.execute_reply.started": "2024-04-06T12:28:12.573751Z"
129
+ },
130
+ "id": "Y4zemXiyE6Fi",
131
+ "trusted": true
132
+ },
133
+ "outputs": [],
134
+ "source": [
135
+ "class Language:\n",
136
+ " def __init__(self, name):\n",
137
+ " self.name = name\n",
138
+ " self.char2index = {'#': 0, '$': 1, '^': 2} # '^': start of sequence, '$' : unknown char, '#' : padding\n",
139
+ " self.index2char = {0: '#', 1: '$', 2: '^'}\n",
140
+ " self.vocab_size = 3 # Count\n",
141
+ "\n",
142
+ " def addWord(self, word):\n",
143
+ " for char in word:\n",
144
+ " self.addChar(char)\n",
145
+ "\n",
146
+ " def addChar(self, char):\n",
147
+ " if char not in self.char2index:\n",
148
+ " self.char2index[char] = self.vocab_size\n",
149
+ " self.index2char[self.vocab_size] = char\n",
150
+ " self.vocab_size += 1\n",
151
+ "\n",
152
+ " def encode(self, s):\n",
153
+ " return [self.char2index[ch] for ch in s]\n",
154
+ "\n",
155
+ " def decode(self, l):\n",
156
+ " return ''.join([self.index2char[i] for i in l])\n",
157
+ "\n",
158
+ " def vocab(self):\n",
159
+ " return self.char2index.keys()\n"
160
+ ]
161
+ },
162
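A quick usage sketch of the Language class just defined: indices 0-2 are reserved for the '#', '$' and '^' special tokens, so the first real character lands at index 3.

```python
lang = Language('eng')         # the class defined in the cell above
lang.addWord('abc')
print(lang.encode('abc'))      # [3, 4, 5]
print(lang.decode([3, 4, 5]))  # 'abc'
print(lang.vocab_size)         # 6 (3 special tokens + a, b, c)
```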
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "execution": {
167
+ "iopub.execute_input": "2024-04-06T12:28:12.584802Z",
168
+ "iopub.status.busy": "2024-04-06T12:28:12.584565Z",
169
+ "iopub.status.idle": "2024-04-06T12:28:12.594791Z",
170
+ "shell.execute_reply": "2024-04-06T12:28:12.593973Z",
171
+ "shell.execute_reply.started": "2024-04-06T12:28:12.584781Z"
172
+ },
173
+ "id": "IDGaCO8DkYpc",
174
+ "trusted": true
175
+ },
176
+ "outputs": [],
177
+ "source": [
178
+ "input_shape = 0\n",
179
+ "def preprocess(data, input_lang, output_lang, s=''):\n",
180
+ "\n",
181
+ " unknown = input_lang.char2index['$']\n",
182
+ "\n",
183
+ " input_max_len = 27\n",
184
+ " output_max_len = max([len(o) for o in data[1]])\n",
185
+ "\n",
186
+ " n = len(data)\n",
187
+ " input = torch.zeros((n, input_max_len + 1), device = device)\n",
188
+ " output = torch.zeros((n, output_max_len + 2), device = device)\n",
189
+ "\n",
190
+ " for i in range(n):\n",
191
+ "\n",
192
+ " inp = data[0][i].ljust(input_max_len + 1, '#')\n",
193
+ " op = '^' + data[1][i] # add start symbol to output\n",
194
+ " op = op.ljust(output_max_len + 2, '#')\n",
195
+ "\n",
196
+ " for index, char in enumerate(inp):\n",
197
+ " if char in input_lang.char2index:\n",
198
+ " input[i][index] = input_lang.char2index[char]\n",
199
+ " else:\n",
200
+ " input[i][index] = unknown\n",
201
+ "\n",
202
+ " for index, char in enumerate(op):\n",
203
+ " if char in output_lang.char2index:\n",
204
+ " output[i][index] = output_lang.char2index[char]\n",
205
+ " else:\n",
206
+ " output[i][index] = unknown\n",
207
+ "\n",
208
+ " print(s, ' dataset')\n",
209
+ " print(input.shape)\n",
210
+ " print(output.shape)\n",
211
+ "\n",
212
+ " return TensorDataset(input.to(torch.int32), output.to(torch.int32))"
213
+ ]
214
+ },
215
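preprocess frames every pair the same way: the English input is right-padded with '#' to a fixed width, and the Hindi output first receives the '^' start symbol and is then padded. The string handling in isolation (the lengths here are illustrative, not the notebook's actual maxima):

```python
input_max_len, output_max_len = 9, 7     # illustrative lengths only
word_in, word_out = 'hankers', 'हैंकर्स'
print(word_in.ljust(input_max_len + 1, '#'))            # 'hankers###'
print(('^' + word_out).ljust(output_max_len + 2, '#'))  # '^' + word, '#'-padded
```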
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {
219
+ "colab": {
220
+ "base_uri": "https://localhost:8080/"
221
+ },
222
+ "execution": {
223
+ "iopub.execute_input": "2024-04-06T12:28:12.596018Z",
224
+ "iopub.status.busy": "2024-04-06T12:28:12.595741Z",
225
+ "iopub.status.idle": "2024-04-06T12:29:16.322883Z",
226
+ "shell.execute_reply": "2024-04-06T12:29:16.321877Z",
227
+ "shell.execute_reply.started": "2024-04-06T12:28:12.595995Z"
228
+ },
229
+ "id": "PdS5OXKxfdCX",
230
+ "outputId": "283fb51a-9a4a-4fc5-bad1-ea66373b29b4",
231
+ "trusted": true
232
+ },
233
+ "outputs": [
234
+ {
235
+ "name": "stdout",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "train dataset\n",
239
+ "torch.Size([51200, 28])\n",
240
+ "torch.Size([51200, 22])\n",
241
+ "validation dataset\n",
242
+ "torch.Size([4096, 28])\n",
243
+ "torch.Size([4096, 22])\n",
244
+ "test dataset\n",
245
+ "torch.Size([4096, 28])\n",
246
+ "torch.Size([4096, 22])\n"
247
+ ]
248
+ }
249
+ ],
250
+ "source": [
251
+ "def load_prepare_data(lang):\n",
252
+ "\n",
253
+ " train_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_train.csv\", header = None)\n",
254
+ " val_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_valid.csv\", header = None)\n",
255
+ " test_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_test.csv\", header = None)\n",
256
+ "\n",
257
+ " input_lang = Language('eng')\n",
258
+ " output_lang = Language(lang)\n",
259
+ "\n",
260
+ " # create vocablury\n",
261
+ " for i in range(len(train_df)):\n",
262
+ " input_lang.addWord(train_df[0][i]) # 'eng'\n",
263
+ " output_lang.addWord(train_df[1][i]) # 'hin'\n",
264
+ "\n",
265
+ " # encode the datasets\n",
266
+ " train_data = preprocess(train_df, input_lang, output_lang, 'train')\n",
267
+ " val_data = preprocess(val_df, input_lang, output_lang, 'validation')\n",
268
+ " test_data = preprocess(test_df, input_lang, output_lang, 'test')\n",
269
+ "\n",
270
+ " return train_data, val_data, test_data, input_lang, output_lang\n",
271
+ "\n",
272
+ "\n",
273
+ "train_data, val_data, test_data, input_lang, output_lang = load_prepare_data('hin')\n"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {
280
+ "colab": {
281
+ "base_uri": "https://localhost:8080/"
282
+ },
283
+ "execution": {
284
+ "iopub.execute_input": "2024-04-06T12:29:16.324674Z",
285
+ "iopub.status.busy": "2024-04-06T12:29:16.324273Z",
286
+ "iopub.status.idle": "2024-04-06T12:29:16.334834Z",
287
+ "shell.execute_reply": "2024-04-06T12:29:16.333992Z",
288
+ "shell.execute_reply.started": "2024-04-06T12:29:16.324643Z"
289
+ },
290
+ "id": "nu-NTR6BDj8e",
291
+ "outputId": "bd3dba2a-092d-4846-a5fb-f703f119b56a",
292
+ "trusted": true
293
+ },
294
+ "outputs": [
295
+ {
296
+ "name": "stdout",
297
+ "output_type": "stream",
298
+ "text": [
299
+ "hankers#####################\n"
300
+ ]
301
+ },
302
+ {
303
+ "data": {
304
+ "text/plain": [
305
+ "'^हैंकर्स##############'"
306
+ ]
307
+ },
308
+ "execution_count": 7,
309
+ "metadata": {},
310
+ "output_type": "execute_result"
311
+ }
312
+ ],
313
+ "source": [
314
+ "print(input_lang.decode(train_data[23][0].tolist()))\n",
315
+ "output_lang.decode(train_data[23][1].tolist())"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": null,
321
+ "metadata": {
322
+ "colab": {
323
+ "base_uri": "https://localhost:8080/"
324
+ },
325
+ "execution": {
326
+ "iopub.execute_input": "2024-04-06T12:29:16.336734Z",
327
+ "iopub.status.busy": "2024-04-06T12:29:16.336128Z",
328
+ "iopub.status.idle": "2024-04-06T12:29:16.355166Z",
329
+ "shell.execute_reply": "2024-04-06T12:29:16.354327Z",
330
+ "shell.execute_reply.started": "2024-04-06T12:29:16.336702Z"
331
+ },
332
+ "id": "yJI8iU6dBSE0",
333
+ "outputId": "818815ee-503e-4dcd-b7a6-5f00a06b5ace",
334
+ "trusted": true
335
+ },
336
+ "outputs": [
337
+ {
338
+ "data": {
339
+ "text/plain": [
340
+ "tensor([ 2, 34, 36, 17, 15, 7, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
341
+ " 0, 0, 0, 0], device='cuda:0', dtype=torch.int32)"
342
+ ]
343
+ },
344
+ "execution_count": 8,
345
+ "metadata": {},
346
+ "output_type": "execute_result"
347
+ }
348
+ ],
349
+ "source": [
350
+ "train_data[23][1]"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "metadata": {
357
+ "colab": {
358
+ "base_uri": "https://localhost:8080/"
359
+ },
360
+ "execution": {
361
+ "iopub.execute_input": "2024-04-06T12:29:16.356467Z",
362
+ "iopub.status.busy": "2024-04-06T12:29:16.356175Z",
363
+ "iopub.status.idle": "2024-04-06T12:29:19.315416Z",
364
+ "shell.execute_reply": "2024-04-06T12:29:19.314522Z",
365
+ "shell.execute_reply.started": "2024-04-06T12:29:16.356444Z"
366
+ },
367
+ "id": "SvmzS5Lt_Jnl",
368
+ "outputId": "1387d646-ea3c-4fbf-b44f-c071e2b07784",
369
+ "trusted": true
370
+ },
371
+ "outputs": [
372
+ {
373
+ "name": "stderr",
374
+ "output_type": "stream",
375
+ "text": [
376
+ "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
377
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
378
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
379
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
380
+ ]
381
+ },
382
+ {
383
+ "data": {
384
+ "text/plain": [
385
+ "True"
386
+ ]
387
+ },
388
+ "execution_count": 9,
389
+ "metadata": {},
390
+ "output_type": "execute_result"
391
+ }
392
+ ],
393
+ "source": [
394
+ "wandb.login(key =\"\")"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "markdown",
399
+ "metadata": {
400
+ "id": "Q1TioafYgICa"
401
+ },
402
+ "source": [
403
+ "# seq2seq tranformer model"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "markdown",
408
+ "metadata": {
409
+ "id": "K94_u35dCk7-"
410
+ },
411
+ "source": [
412
+ "### hyperparameter settings"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "metadata": {
419
+ "execution": {
420
+ "iopub.execute_input": "2024-04-06T12:29:19.318625Z",
421
+ "iopub.status.busy": "2024-04-06T12:29:19.318195Z",
422
+ "iopub.status.idle": "2024-04-06T12:29:19.324068Z",
423
+ "shell.execute_reply": "2024-04-06T12:29:19.323194Z",
424
+ "shell.execute_reply.started": "2024-04-06T12:29:19.318601Z"
425
+ },
426
+ "id": "PugX7KHvc65u",
427
+ "trusted": true
428
+ },
429
+ "outputs": [],
430
+ "source": [
431
+ "n_embd = 64\n",
432
+ "batch_size = 256\n",
433
+ "learning_rate = 1e-3\n",
434
+ "n_head = 4 # other options factors of 32 like 2, 8\n",
435
+ "n_layers = 6\n",
436
+ "dropout = 0.2\n",
437
+ "epochs = 50\n",
438
+ "\n",
439
+ "# encoder specific detail\n",
440
+ "input_vocab_size = input_lang.vocab_size\n",
441
+ "encoder_block_size = len(train_data[0][0])\n",
442
+ "\n",
443
+ "# decoder specific detail\n",
444
+ "output_vocab_size = output_lang.vocab_size\n",
445
+ "decoder_block_size = len(train_data[0][1])"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "markdown",
450
+ "metadata": {
451
+ "id": "XdltQ7oJCq1j"
452
+ },
453
+ "source": [
454
+ "### Encoder model"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "metadata": {
461
+ "execution": {
462
+ "iopub.execute_input": "2024-04-06T12:29:19.325685Z",
463
+ "iopub.status.busy": "2024-04-06T12:29:19.325424Z",
464
+ "iopub.status.idle": "2024-04-06T12:29:19.351414Z",
465
+ "shell.execute_reply": "2024-04-06T12:29:19.350579Z",
466
+ "shell.execute_reply.started": "2024-04-06T12:29:19.325663Z"
467
+ },
468
+ "id": "uiluDiY7FAMU",
469
+ "trusted": true
470
+ },
471
+ "outputs": [],
472
+ "source": [
473
+ "class Head(nn.Module):\n",
474
+ " \"\"\" one self-attention head \"\"\"\n",
475
+ "\n",
476
+ " def __init__(self, n_embd, d_k, dropout, mask=0): # d_k is dimention of key , nomaly d_k = n_embd / 4\n",
477
+ " super().__init__()\n",
478
+ " self.mask = mask\n",
479
+ " self.key = nn.Linear(n_embd, d_k, bias=False, device=device)\n",
480
+ " self.query = nn.Linear(n_embd, d_k, bias=False, device=device)\n",
481
+ " self.value = nn.Linear(n_embd, d_k, bias=False, device=device)\n",
482
+ " if mask:\n",
483
+ " self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))\n",
484
+ " self.dropout = nn.Dropout(dropout)\n",
485
+ "\n",
486
+ " def forward(self, x, encoder_output = None):\n",
487
+ " B,T,C = x.shape\n",
488
+ "\n",
489
+ " if encoder_output is not None:\n",
490
+ " k = self.key(encoder_output)\n",
491
+ " Be, Te, Ce = encoder_output.shape\n",
492
+ " else:\n",
493
+ " k = self.key(x) # (B,T,d_k)\n",
494
+ "\n",
495
+ " q = self.query(x) # (B,T,d_k)\n",
496
+ " # compute attention scores\n",
497
+ " wei = q @ k.transpose(-2, -1) * C**-0.5 # (B,T,T)\n",
498
+ "\n",
499
+ " if self.mask:\n",
500
+ " if encoder_output is not None:\n",
501
+ " wei = wei.masked_fill(self.tril[:T, :Te] == 0, float('-inf')) # (B,T,T)\n",
502
+ " else:\n",
503
+ " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)\n",
504
+ "\n",
505
+ " wei = F.softmax(wei, dim=-1)\n",
506
+ " wei = self.dropout(wei)\n",
507
+ " # perform weighted aggregation of values\n",
508
+ " if encoder_output is not None:\n",
509
+ " v = self.value(encoder_output)\n",
510
+ " else:\n",
511
+ " v = self.value(x)\n",
512
+ " out = wei @ v # (B,T,C)\n",
513
+ " return out\n",
514
+ "\n",
515
+ "class MultiHeadAttention(nn.Module):\n",
516
+ " \"\"\" multiple self attention heads in parallel \"\"\"\n",
517
+ "\n",
518
+ " def __init__(self, n_embd, num_head, d_k, dropout, mask=0):\n",
519
+ " super().__init__()\n",
520
+ " self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])\n",
521
+ " self.proj = nn.Linear(n_embd, n_embd)\n",
522
+ " self.dropout = nn.Dropout(dropout)\n",
523
+ "\n",
524
+ " def forward(self, x, encoder_output=None):\n",
525
+ " out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)\n",
526
+ " out = self.dropout(self.proj(out))\n",
527
+ " return out\n",
528
+ "\n",
529
+ "class FeedForward(nn.Module):\n",
530
+ " \"\"\" multiple self attention heads in parallel \"\"\"\n",
531
+ "\n",
532
+ " def __init__(self, n_embd, dropout):\n",
533
+ " super().__init__()\n",
534
+ " self.net = nn.Sequential(\n",
535
+ " nn.Linear(n_embd, 4 * n_embd),\n",
536
+ " nn.ReLU(),\n",
537
+ " nn.Linear(4 * n_embd, n_embd),\n",
538
+ " nn.Dropout(dropout)\n",
539
+ " )\n",
540
+ "\n",
541
+ " def forward(self, x):\n",
542
+ " return self.net(x)\n",
543
+ "\n",
544
+ "class encoderBlock(nn.Module):\n",
545
+ " \"\"\" Tranformer encoder block : communication followed by computation \"\"\"\n",
546
+ "\n",
547
+ " def __init__(self, n_embd, n_head, dropout):\n",
548
+ " super().__init__()\n",
549
+ " d_k = n_embd // n_head\n",
550
+ " self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout)\n",
551
+ " self.ffwd = FeedForward(n_embd, dropout)\n",
552
+ " self.ln1 = nn.LayerNorm(n_embd)\n",
553
+ " self.ln2 = nn.LayerNorm(n_embd)\n",
554
+ "\n",
555
+ " def forward(self, x, encoder_output=None):\n",
556
+ " x = x + self.sa(self.ln1(x), encoder_output)\n",
557
+ " x = x + self.ffwd(self.ln2(x))\n",
558
+ " return x\n",
559
+ "\n",
560
+ "class Encoder(nn.Module):\n",
561
+ "\n",
562
+ " def __init__(self, n_embd, n_head, n_layers, dropout):\n",
563
+ " super().__init__()\n",
564
+ "\n",
565
+ " self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd) # n_embd: input embedding dimension\n",
566
+ " self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)\n",
567
+ " self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])\n",
568
+ " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
569
+ "\n",
570
+ " def forward(self, idx):\n",
571
+ " B, T = idx.shape\n",
572
+ " tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)\n",
573
+ " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)\n",
574
+ " x = tok_emb + pos_emb # (B,T,n_embd)\n",
575
+ " x = self.blocks(x) # apply one attention layer (B,T,C)\n",
576
+ " x = self.ln_f(x) # (B,T,C)\n",
577
+ " return x\n"
578
+ ]
579
+ },
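Note that `Head.forward` scales the attention scores by `C**-0.5`, the full embedding width, while "Attention is All You Need" scales by `d_k**-0.5`, the per-head width. A self-contained sketch of masked scaled dot-product attention with paper-style scaling, for shape-checking (all sizes below are made up for illustration):

```python
import torch
import torch.nn.functional as F

B, T, d_k = 2, 5, 16
q, k, v = (torch.randn(B, T, d_k) for _ in range(3))  # fake projections

wei = q @ k.transpose(-2, -1) * d_k**-0.5        # (B,T,T) scores, scaled by sqrt(d_k)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # causal mask: no peeking ahead
wei = F.softmax(wei, dim=-1)                     # each row sums to 1
out = wei @ v                                    # (B,T,d_k)
print(out.shape)                                 # torch.Size([2, 5, 16])
```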
580
+ {
581
+ "cell_type": "markdown",
582
+ "metadata": {
583
+ "id": "GgPU486JC8Mz"
584
+ },
585
+ "source": [
586
+ "### Decoder model"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": null,
592
+ "metadata": {
593
+ "execution": {
594
+ "iopub.execute_input": "2024-04-06T12:29:19.352896Z",
595
+ "iopub.status.busy": "2024-04-06T12:29:19.352571Z",
596
+ "iopub.status.idle": "2024-04-06T12:29:19.367829Z",
597
+ "shell.execute_reply": "2024-04-06T12:29:19.366971Z",
598
+ "shell.execute_reply.started": "2024-04-06T12:29:19.352872Z"
599
+ },
600
+ "id": "JteOV0CdC_bv",
601
+ "trusted": true
602
+ },
603
+ "outputs": [],
604
+ "source": [
605
+ "class decoderBlock(nn.Module):\n",
606
+ " \"\"\" Tranformer decoder block : self communication then cross communication followed by computation \"\"\"\n",
607
+ "\n",
608
+ " def __init__(self, n_embd, n_head, dropout):\n",
609
+ " super().__init__()\n",
610
+ " d_k = n_embd // n_head\n",
611
+ " self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask = 1)\n",
612
+ " self.ca = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask = 1)\n",
613
+ " self.ffwd = FeedForward(n_embd, dropout)\n",
614
+ " self.ln1 = nn.LayerNorm(n_embd, device=device)\n",
615
+ " self.ln2 = nn.LayerNorm(n_embd, device=device)\n",
616
+ " self.ln3 = nn.LayerNorm(n_embd, device=device)\n",
617
+ "\n",
618
+ " def forward(self, x_encoder_output):\n",
619
+ " x = x_encoder_output[0]\n",
620
+ " encoder_output = x_encoder_output[1]\n",
621
+ " x = x + self.sa(self.ln1(x))\n",
622
+ " x = x + self.ca(self.ln2(x), encoder_output)\n",
623
+ " x = x + self.ffwd(self.ln3(x))\n",
624
+ " return (x,encoder_output)\n",
625
+ "\n",
626
+ "class Decoder(nn.Module):\n",
627
+ "\n",
628
+ " def __init__(self, n_embd, n_head, n_layers, dropout):\n",
629
+ " super().__init__()\n",
630
+ "\n",
631
+ " self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd) # n_embd: input embedding dimension\n",
632
+ " self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)\n",
633
+ " self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])\n",
634
+ " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
635
+ " self.lm_head = nn.Linear(n_embd, output_vocab_size)\n",
636
+ "\n",
637
+ " def forward(self, idx, encoder_output, targets=None):\n",
638
+ " B, T = idx.shape\n",
639
+ "\n",
640
+ " tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)\n",
641
+ " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)\n",
642
+ " x = tok_emb + pos_emb # (B,T,n_embd)\n",
643
+ "\n",
644
+ " x =self.blocks((x, encoder_output))\n",
645
+ " x = self.ln_f(x[0]) # (B,T,C)\n",
646
+ " logits = self.lm_head(x) # (B,T,output_vocab_size)\n",
647
+ "\n",
648
+ " if targets is None:\n",
649
+ " loss = None\n",
650
+ " else:\n",
651
+ " B, T, C = logits.shape\n",
652
+ " temp_logits = logits.view(B*T, C)\n",
653
+ " targets = targets.reshape(B*T)\n",
654
+ "\n",
655
+ " loss = F.cross_entropy(temp_logits, targets.long())\n",
656
+ "\n",
657
+ " # print(logits)\n",
658
+ " # out = torch.argmax(logits)\n",
659
+ "\n",
660
+ " return logits, loss\n",
661
+ "\n"
662
+ ]
663
+ },
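One detail worth flagging: `ca` is built with `mask = 1`, so the cross-attention here is also causally masked via `tril[:T, :Te]`, whereas the original Transformer masks only decoder self-attention and lets every decoder position see the whole encoder output. A minimal sketch of unmasked cross-attention, with made-up sizes, to show the shapes involved:

```python
import torch

B, T_dec, T_enc, d_k = 2, 7, 5, 16
q = torch.randn(B, T_dec, d_k)   # queries come from decoder states
k = torch.randn(B, T_enc, d_k)   # keys come from the encoder output
v = torch.randn(B, T_enc, d_k)   # values come from the encoder output

wei = torch.softmax(q @ k.transpose(-2, -1) * d_k**-0.5, dim=-1)  # (B, T_dec, T_enc)
out = wei @ v                                                     # (B, T_dec, d_k)
print(out.shape)                                                  # torch.Size([2, 7, 16])
```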
664
+ {
665
+ "cell_type": "markdown",
666
+ "metadata": {
667
+ "id": "EBjmsIcklM8Y"
668
+ },
669
+ "source": [
670
+ "# Training Time"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "markdown",
675
+ "metadata": {
676
+ "id": "lLfHEDk8FNfY"
677
+ },
678
+ "source": [
679
+ "## sweep config"
680
+ ]
681
+ },
682
+ {
683
+ "cell_type": "code",
684
+ "execution_count": null,
685
+ "metadata": {
686
+ "execution": {
687
+ "iopub.execute_input": "2024-04-04T14:54:15.308213Z",
688
+ "iopub.status.busy": "2024-04-04T14:54:15.307981Z",
689
+ "iopub.status.idle": "2024-04-04T14:54:15.319933Z",
690
+ "shell.execute_reply": "2024-04-04T14:54:15.319070Z",
691
+ "shell.execute_reply.started": "2024-04-04T14:54:15.308192Z"
692
+ },
693
+ "id": "nDcRZmb80msE",
694
+ "trusted": true
695
+ },
696
+ "outputs": [],
697
+ "source": [
698
+ "# Define sweep config\n",
699
+ "sweep_configuration = {\n",
700
+ " \"method\": \"bayes\",\n",
701
+ " \"name\": \"sweep\",\n",
702
+ " \"metric\": {\"goal\": \"maximize\", \"name\": \"val_acc\"},\n",
703
+ " \"parameters\": {\n",
704
+ " \"batch_size\": {\"values\": [64, 128, 256]},\n",
705
+ " \"epochs\": {\"values\": [20, 40, 50, 100]},\n",
706
+ " \"lr\": {\"max\": 0.1, \"min\": 0.0001},\n",
707
+ " \"n_embd\": {\"values\": [16, 32, 64]},\n",
708
+ " \"n_head\": {\"values\": [2, 4, 8]},\n",
709
+ " \"n_layers\": {\"values\": [4, 6, 8]},\n",
710
+ " \"dropout\": {\"values\": [0, .1, .2, .3]}\n",
711
+ " },\n",
712
+ "}\n",
713
+ "\n",
714
+ "sweep_id = wandb.sweep(sweep=sweep_configuration, project=\"Tranliteration-Tranformers\")"
715
+ ]
716
+ },
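The `lr` range above spans three orders of magnitude, and with a bare min/max wandb samples it uniformly, which over-weights the large values. If log-scale sampling is preferred, recent wandb releases accept a `distribution` key; a hedged sketch of the alternative entry (verify the exact key name against the wandb docs for your installed version):

```python
# Hypothetical replacement for the "lr" entry, sampling on a log scale.
lr_spec = {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-1}
```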
717
+ {
718
+ "cell_type": "code",
719
+ "execution_count": null,
720
+ "metadata": {
721
+ "execution": {
722
+ "iopub.execute_input": "2024-04-04T14:54:15.325199Z",
723
+ "iopub.status.busy": "2024-04-04T14:54:15.324615Z",
724
+ "iopub.status.idle": "2024-04-04T14:54:15.330172Z",
725
+ "shell.execute_reply": "2024-04-04T14:54:15.329301Z",
726
+ "shell.execute_reply.started": "2024-04-04T14:54:15.325168Z"
727
+ },
728
+ "id": "9CguGUG5_1NL",
729
+ "trusted": true
730
+ },
731
+ "outputs": [],
732
+ "source": [
733
+ "# wandb.sweep_cancel(sweep_id)\n",
734
+ "# wandb.finish()\n",
735
+ "# wandb.run.cancel()"
736
+ ]
737
+ },
738
+ {
739
+ "cell_type": "markdown",
740
+ "metadata": {
741
+ "id": "d5T58TQRECbZ"
742
+ },
743
+ "source": [
744
+ "## train function"
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "execution_count": null,
750
+ "metadata": {
751
+ "execution": {
752
+ "iopub.execute_input": "2024-04-04T14:54:15.331837Z",
753
+ "iopub.status.busy": "2024-04-04T14:54:15.331538Z",
754
+ "iopub.status.idle": "2024-04-04T14:54:15.351924Z",
755
+ "shell.execute_reply": "2024-04-04T14:54:15.351027Z",
756
+ "shell.execute_reply.started": "2024-04-04T14:54:15.331810Z"
757
+ },
758
+ "id": "3GWnCggNFLs3",
759
+ "trusted": true
760
+ },
761
+ "outputs": [],
762
+ "source": [
763
+ "def train():\n",
764
+ " run = wandb.init()\n",
765
+ "\n",
766
+ " n_embd = wandb.config.n_embd\n",
767
+ " n_head = wandb.config.n_head\n",
768
+ " n_layers = wandb.config.n_layers\n",
769
+ " dropout = wandb.config.dropout\n",
770
+ " epochs = wandb.config.epochs\n",
771
+ " batch_size = wandb.config.batch_size\n",
772
+ " learning_rate = wandb.config.lr\n",
773
+ "\n",
774
+ "\n",
775
+ " encoder = Encoder(n_embd, n_head, n_layers, dropout)\n",
776
+ " decoder = Decoder(n_embd, n_head, n_layers, dropout)\n",
777
+ " encoder.to(device)\n",
778
+ " decoder.to(device)\n",
779
+ "\n",
780
+ " train_losses, train_accuracies, val_losses, val_accuracies = [], [], [], []\n",
781
+ "\n",
782
+ " # print the number of parameters in the model\n",
783
+ " print(sum([p.numel() for p in encoder.parameters()] + [p.numel() for p in decoder.parameters()])/1e3, 'K model parameters')\n",
784
+ "\n",
785
+ " train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
786
+ " val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
787
+ "\n",
788
+ " # create a PyTorch optimizer\n",
789
+ " encoder_optimizer = torch.optim.AdamW(encoder.parameters(), lr=learning_rate)\n",
790
+ " decoder_optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)\n",
791
+ "\n",
792
+ "# print('Step | Training Loss | Validation Loss | Training Accuracy % | Validation Accuracy %')\n",
793
+ "\n",
794
+ " least_error = float('inf')\n",
795
+ " patience = 20 # The number of epochs without improvement to wait before stopping\n",
796
+ " no_improvement = 0\n",
797
+ "\n",
798
+ " for i in range(epochs):\n",
799
+ " running_loss = 0.0\n",
800
+ " train_correct = 0\n",
801
+ "\n",
802
+ " encoder.train()\n",
803
+ " decoder.train()\n",
804
+ "\n",
805
+ " for j,(train_x,train_y) in enumerate(train_loader):\n",
806
+ " train_x = train_x.to(device)\n",
807
+ " train_y = train_y.to(device)\n",
808
+ "\n",
809
+ " encoder_optimizer.zero_grad(set_to_none=True)\n",
810
+ " decoder_optimizer.zero_grad(set_to_none=True)\n",
811
+ "\n",
812
+ " encoder_output = encoder(train_x)\n",
813
+ " logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n",
814
+ "\n",
817
+ " loss.backward()\n",
818
+ " encoder_optimizer.step()\n",
819
+ " decoder_optimizer.step()\n",
820
+ "\n",
821
+ " running_loss += loss\n",
822
+ " pred_decoder_output = torch.argmax(logits, dim=-1)\n",
823
+ " # print(pred_decoder_output, \" target: \", train_y[:, 1:])\n",
824
+ " train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n",
825
+ "\n",
826
+ "\n",
827
+ " ## validation code\n",
828
+ " running_loss_val, val_correct = 0, 0\n",
829
+ " encoder.eval()\n",
830
+ " decoder.eval()\n",
831
+ " for j,(val_x,val_y) in enumerate(val_loader):\n",
832
+ " val_x = val_x.to(device)\n",
833
+ " val_y = val_y.to(device)\n",
834
+ "\n",
835
+ " encoder_output = encoder(val_x)\n",
836
+ " logits, loss = decoder(val_y[:, :-1], encoder_output, val_y[:, 1:])\n",
837
+ "\n",
838
+ " running_loss_val += loss\n",
839
+ " pred_decoder_output = torch.argmax(logits, dim=-1)\n",
840
+ " val_correct += torch.sum(pred_decoder_output == val_y[:, 1:])\n",
841
+ "\n",
842
+ "\n",
843
+ " if running_loss_val < least_error:\n",
844
+ " least_error = running_loss_val\n",
845
+ " no_improvement = 0\n",
846
+ " else:\n",
847
+ " no_improvement += 1\n",
848
+ "\n",
849
+ " if no_improvement >= patience:\n",
850
+ " print(f\"Early stopping at epoch {i}\")\n",
851
+ " break\n",
852
+ "\n",
853
+ " wandb.log(\n",
854
+ " {\n",
855
+ " \"train_loss\": running_loss / len(train_data),\n",
856
+ " \"val_loss\": (running_loss_val/len(val_data)),\n",
857
+ " \"train_acc\": ((train_correct*100) / (len(train_data)* (decoder_block_size-1))),\n",
858
+ " \"val_acc\": ((val_correct*100)/(len(val_data)* (decoder_block_size-1))),\n",
859
+ " }\n",
860
+ " )"
861
+ ]
862
+ },
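The train function uses teacher forcing: the decoder is fed `train_y[:, :-1]` and scored against `train_y[:, 1:]`, so every word contributes `decoder_block_size - 1` predictions, which is exactly the denominator in the logged accuracies. A toy illustration of the shift (token indices made up):

```python
import torch

# Toy target row: <sos> t o y <eos>
y = torch.tensor([[2, 10, 11, 12, 3]])
decoder_input  = y[:, :-1]      # <sos> t o y
decoder_target = y[:, 1:]       # t o y <eos>
print(decoder_input.tolist())   # [[2, 10, 11, 12]]
print(decoder_target.tolist())  # [[10, 11, 12, 3]]
# predictions per word = y.shape[1] - 1, hence the (decoder_block_size - 1) factor
```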
863
+ {
864
+ "cell_type": "markdown",
865
+ "metadata": {
866
+ "id": "CxzRR9cjEGDm"
867
+ },
868
+ "source": [
869
+ "## run sweep"
870
+ ]
871
+ },
872
+ {
873
+ "cell_type": "code",
874
+ "execution_count": null,
875
+ "metadata": {
876
+ "colab": {
877
+ "base_uri": "https://localhost:8080/",
878
+ "height": 295,
879
+ "referenced_widgets": [
880
+ ""
881
+ ]
882
+ },
883
+ "execution": {
884
+ "iopub.execute_input": "2024-04-04T14:54:15.353688Z",
885
+ "iopub.status.busy": "2024-04-04T14:54:15.353125Z"
886
+ },
887
+ "id": "u_QFbYe32t7r",
888
+ "outputId": "97153eab-b36f-454b-9fed-53ae0287aee1",
889
+ "trusted": true
890
+ },
891
+ "outputs": [
892
+ {
893
+ "name": "stderr",
894
+ "output_type": "stream",
895
+ "text": [
896
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: dcco6zur with config:\n",
897
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 64\n",
898
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0\n",
899
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 50\n",
900
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.0003\n",
901
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 64\n",
902
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n",
903
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 6\n",
904
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcs22m062\u001b[0m (\u001b[33miitmadras\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
905
+ ]
906
+ },
907
+ {
908
+ "data": {
909
+ "text/html": [
910
+ "wandb version 0.16.6 is available! To upgrade, please run:\n",
911
+ " $ pip install wandb --upgrade"
912
+ ],
913
+ "text/plain": [
914
+ "<IPython.core.display.HTML object>"
915
+ ]
916
+ },
917
+ "metadata": {},
918
+ "output_type": "display_data"
919
+ },
920
+ {
921
+ "data": {
922
+ "text/html": [
923
+ "Tracking run with wandb version 0.16.4"
924
+ ],
925
+ "text/plain": [
926
+ "<IPython.core.display.HTML object>"
927
+ ]
928
+ },
929
+ "metadata": {},
930
+ "output_type": "display_data"
931
+ },
932
+ {
933
+ "data": {
934
+ "text/html": [
935
+ "Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_145417-dcco6zur</code>"
936
+ ],
937
+ "text/plain": [
938
+ "<IPython.core.display.HTML object>"
939
+ ]
940
+ },
941
+ "metadata": {},
942
+ "output_type": "display_data"
943
+ },
944
+ {
945
+ "data": {
946
+ "text/html": [
947
+ "Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur' target=\"_blank\">eager-sweep-2</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
948
+ ],
949
+ "text/plain": [
950
+ "<IPython.core.display.HTML object>"
951
+ ]
952
+ },
953
+ "metadata": {},
954
+ "output_type": "display_data"
955
+ },
956
+ {
957
+ "data": {
958
+ "text/html": [
959
+ " View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
960
+ ],
961
+ "text/plain": [
962
+ "<IPython.core.display.HTML object>"
963
+ ]
964
+ },
965
+ "metadata": {},
966
+ "output_type": "display_data"
967
+ },
968
+ {
969
+ "data": {
970
+ "text/html": [
971
+ " View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
972
+ ],
973
+ "text/plain": [
974
+ "<IPython.core.display.HTML object>"
975
+ ]
976
+ },
977
+ "metadata": {},
978
+ "output_type": "display_data"
979
+ },
980
+ {
981
+ "data": {
982
+ "text/html": [
983
+ " View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur</a>"
984
+ ],
985
+ "text/plain": [
986
+ "<IPython.core.display.HTML object>"
987
+ ]
988
+ },
989
+ "metadata": {},
990
+ "output_type": "display_data"
991
+ },
992
+ {
993
+ "name": "stdout",
994
+ "output_type": "stream",
995
+ "text": [
996
+ "710.915 K model parameters\n",
997
+ "Early stopping at epoch 32\n"
998
+ ]
999
+ },
1000
+ {
1001
+ "data": {
1002
+ "application/vnd.jupyter.widget-view+json": {
1003
+ "model_id": "",
1004
+ "version_major": 2,
1005
+ "version_minor": 0
1006
+ },
1007
+ "text/plain": [
1008
+ "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
1009
+ ]
1010
+ },
1011
+ "metadata": {},
1012
+ "output_type": "display_data"
1013
+ },
1014
+ {
1015
+ "data": {
1016
+ "text/html": [
1017
+ "<style>\n",
1018
+ " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
1019
+ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
1020
+ " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
1021
+ " </style>\n",
1022
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>▁▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████</td></tr><tr><td>train_loss</td><td>█▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_acc</td><td>▁▅▆▆▇▇▇▇▇▇██████████████████████</td></tr><tr><td>val_loss</td><td>█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>97.8125</td></tr><tr><td>train_loss</td><td>0.00096</td></tr><tr><td>val_acc</td><td>95.29739</td></tr><tr><td>val_loss</td><td>0.00286</td></tr></table><br/></div></div>"
1023
+ ],
1024
+ "text/plain": [
1025
+ "<IPython.core.display.HTML object>"
1026
+ ]
1027
+ },
1028
+ "metadata": {},
1029
+ "output_type": "display_data"
1030
+ },
1031
+ {
1032
+ "data": {
1033
+ "text/html": [
1034
+ " View run <strong style=\"color:#cdcd00\">eager-sweep-2</strong> at: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
1035
+ ],
1036
+ "text/plain": [
1037
+ "<IPython.core.display.HTML object>"
1038
+ ]
1039
+ },
1040
+ "metadata": {},
1041
+ "output_type": "display_data"
1042
+ },
1043
+ {
1044
+ "data": {
1045
+ "text/html": [
1046
+ "Find logs at: <code>./wandb/run-20240404_145417-dcco6zur/logs</code>"
1047
+ ],
1048
+ "text/plain": [
1049
+ "<IPython.core.display.HTML object>"
1050
+ ]
1051
+ },
1052
+ "metadata": {},
1053
+ "output_type": "display_data"
1054
+ },
1055
+ {
1056
+ "name": "stderr",
1057
+ "output_type": "stream",
1058
+ "text": [
1059
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: 4qb2bmi8 with config:\n",
1060
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 128\n",
1061
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0.1\n",
1062
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 20\n",
1063
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.03\n",
1064
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 16\n",
1065
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n",
1066
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 6\n"
1067
+ ]
1068
+ },
1069
+ {
1070
+ "data": {
1071
+ "text/html": [
1072
+ "wandb version 0.16.6 is available! To upgrade, please run:\n",
1073
+ " $ pip install wandb --upgrade"
1074
+ ],
1075
+ "text/plain": [
1076
+ "<IPython.core.display.HTML object>"
1077
+ ]
1078
+ },
1079
+ "metadata": {},
1080
+ "output_type": "display_data"
1081
+ },
1082
+ {
1083
+ "data": {
1084
+ "text/html": [
1085
+ "Tracking run with wandb version 0.16.4"
1086
+ ],
1087
+ "text/plain": [
1088
+ "<IPython.core.display.HTML object>"
1089
+ ]
1090
+ },
1091
+ "metadata": {},
1092
+ "output_type": "display_data"
1093
+ },
1094
+ {
1095
+ "data": {
1096
+ "text/html": [
1097
+ "Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_153243-4qb2bmi8</code>"
1098
+ ],
1099
+ "text/plain": [
1100
+ "<IPython.core.display.HTML object>"
1101
+ ]
1102
+ },
1103
+ "metadata": {},
1104
+ "output_type": "display_data"
1105
+ },
1106
+ {
1107
+ "data": {
1108
+ "text/html": [
1109
+ "Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8' target=\"_blank\">peach-sweep-3</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
1110
+ ],
1111
+ "text/plain": [
1112
+ "<IPython.core.display.HTML object>"
1113
+ ]
1114
+ },
1115
+ "metadata": {},
1116
+ "output_type": "display_data"
1117
+ },
1118
+ {
1119
+ "data": {
1120
+ "text/html": [
1121
+ " View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
1122
+ ],
1123
+ "text/plain": [
1124
+ "<IPython.core.display.HTML object>"
1125
+ ]
1126
+ },
1127
+ "metadata": {},
1128
+ "output_type": "display_data"
1129
+ },
1130
+ {
1131
+ "data": {
1132
+ "text/html": [
1133
+ " View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
1134
+ ],
1135
+ "text/plain": [
1136
+ "<IPython.core.display.HTML object>"
1137
+ ]
1138
+ },
1139
+ "metadata": {},
1140
+ "output_type": "display_data"
1141
+ },
1142
+ {
1143
+ "data": {
1144
+ "text/html": [
1145
+ " View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8</a>"
1146
+ ],
1147
+ "text/plain": [
1148
+ "<IPython.core.display.HTML object>"
1149
+ ]
1150
+ },
1151
+ "metadata": {},
1152
+ "output_type": "display_data"
1153
+ },
1154
+ {
1155
+ "name": "stdout",
1156
+ "output_type": "stream",
1157
+ "text": [
1158
+ "48.755 K model parameters\n"
1159
+ ]
1160
+ },
1161
+ {
1162
+ "data": {
1163
+ "application/vnd.jupyter.widget-view+json": {
1164
+ "model_id": "",
1165
+ "version_major": 2,
1166
+ "version_minor": 0
1167
+ },
1168
+ "text/plain": [
1169
+ "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
1170
+ ]
1171
+ },
1172
+ "metadata": {},
1173
+ "output_type": "display_data"
1174
+ },
1175
+ {
1176
+ "data": {
1177
+ "text/html": [
1178
+ "<style>\n",
1179
+ " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
1180
+ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
1181
+ " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
1182
+ " </style>\n",
1183
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>▁▅▇▇▇███████████████</td></tr><tr><td>train_loss</td><td>█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_acc</td><td>▁▆▇▇▇███████████████</td></tr><tr><td>val_loss</td><td>█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>89.37686</td></tr><tr><td>train_loss</td><td>0.00256</td></tr><tr><td>val_acc</td><td>92.66765</td></tr><tr><td>val_loss</td><td>0.0018</td></tr></table><br/></div></div>"
1184
+ ],
1185
+ "text/plain": [
1186
+ "<IPython.core.display.HTML object>"
1187
+ ]
1188
+ },
1189
+ "metadata": {},
1190
+ "output_type": "display_data"
1191
+ },
1192
+ {
1193
+ "data": {
1194
+ "text/html": [
1195
+ " View run <strong style=\"color:#cdcd00\">peach-sweep-3</strong> at: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
1196
+ ],
1197
+ "text/plain": [
1198
+ "<IPython.core.display.HTML object>"
1199
+ ]
1200
+ },
1201
+ "metadata": {},
1202
+ "output_type": "display_data"
1203
+ },
1204
+ {
1205
+ "data": {
1206
+ "text/html": [
1207
+ "Find logs at: <code>./wandb/run-20240404_153243-4qb2bmi8/logs</code>"
1208
+ ],
1209
+ "text/plain": [
1210
+ "<IPython.core.display.HTML object>"
1211
+ ]
1212
+ },
1213
+ "metadata": {},
1214
+ "output_type": "display_data"
1215
+ },
1216
+ {
1217
+ "name": "stderr",
1218
+ "output_type": "stream",
1219
+ "text": [
1220
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: gtz48xe5 with config:\n",
1221
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 32\n",
1222
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0\n",
1223
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 30\n",
1224
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.01\n",
1225
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 16\n",
1226
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n",
1227
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 4\n"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "data": {
1232
+ "text/html": [
1233
+ "wandb version 0.16.6 is available! To upgrade, please run:\n",
1234
+ " $ pip install wandb --upgrade"
1235
+ ],
1236
+ "text/plain": [
1237
+ "<IPython.core.display.HTML object>"
1238
+ ]
1239
+ },
1240
+ "metadata": {},
1241
+ "output_type": "display_data"
1242
+ },
1243
+ {
1244
+ "data": {
1245
+ "text/html": [
1246
+ "Tracking run with wandb version 0.16.4"
1247
+ ],
1248
+ "text/plain": [
1249
+ "<IPython.core.display.HTML object>"
1250
+ ]
1251
+ },
1252
+ "metadata": {},
1253
+ "output_type": "display_data"
1254
+ },
1255
+ {
1256
+ "data": {
1257
+ "text/html": [
1258
+ "Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_154533-gtz48xe5</code>"
1259
+ ],
1260
+ "text/plain": [
1261
+ "<IPython.core.display.HTML object>"
1262
+ ]
1263
+ },
1264
+ "metadata": {},
1265
+ "output_type": "display_data"
1266
+ },
1267
+ {
1268
+ "data": {
1269
+ "text/html": [
1270
+ "Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5' target=\"_blank\">cerulean-sweep-4</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
1271
+ ],
1272
+ "text/plain": [
1273
+ "<IPython.core.display.HTML object>"
1274
+ ]
1275
+ },
1276
+ "metadata": {},
1277
+ "output_type": "display_data"
1278
+ },
1279
+ {
1280
+ "data": {
1281
+ "text/html": [
1282
+ " View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
1283
+ ],
1284
+ "text/plain": [
1285
+ "<IPython.core.display.HTML object>"
1286
+ ]
1287
+ },
1288
+ "metadata": {},
1289
+ "output_type": "display_data"
1290
+ },
1291
+ {
1292
+ "data": {
1293
+ "text/html": [
1294
+ " View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
1295
+ ],
1296
+ "text/plain": [
1297
+ "<IPython.core.display.HTML object>"
1298
+ ]
1299
+ },
1300
+ "metadata": {},
1301
+ "output_type": "display_data"
1302
+ },
1303
+ {
1304
+ "data": {
1305
+ "text/html": [
1306
+ " View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5</a>"
1307
+ ],
1308
+ "text/plain": [
1309
+ "<IPython.core.display.HTML object>"
1310
+ ]
1311
+ },
1312
+ "metadata": {},
1313
+ "output_type": "display_data"
1314
+ },
1315
+ {
1316
+ "name": "stdout",
1317
+ "output_type": "stream",
1318
+ "text": [
1319
+ "33.683 K model parameters\n"
1320
+ ]
1321
+ },
1322
+ {
1323
+ "data": {
1324
+ "application/vnd.jupyter.widget-view+json": {
1325
+ "model_id": "",
1326
+ "version_major": 2,
1327
+ "version_minor": 0
1328
+ },
1329
+ "text/plain": [
1330
+ "VBox(children=(Label(value='0.001 MB of 0.047 MB uploaded\\r'), FloatProgress(value=0.028017589156043247, max=1…"
1331
+ ]
1332
+ },
1333
+ "metadata": {},
1334
+ "output_type": "display_data"
1335
+ },
1336
+ {
1337
+ "data": {
1338
+ "text/html": [
1339
+ "<style>\n",
1340
+ " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
1341
+ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
1342
+ " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
1343
+ " </style>\n",
1344
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>▁▆▆▇▇▇▇▇▇█████████████████████</td></tr><tr><td>train_loss</td><td>█▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_acc</td><td>▁▃▄▅▅▆▇▆▇▆▇▇▆▆▆▇▆▇█▇▇▇▇██▇▇▇█▇</td></tr><tr><td>val_loss</td><td>█▆▅▃▃▃▂▂▂▂▂▂▃▂▂▂▃▂▂▂▂▂▂▁▁▂▂▂▁▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>92.21615</td></tr><tr><td>train_loss</td><td>0.00725</td></tr><tr><td>val_acc</td><td>93.30009</td></tr><tr><td>val_loss</td><td>0.00663</td></tr></table><br/></div></div>"
1345
+ ],
1346
+ "text/plain": [
1347
+ "<IPython.core.display.HTML object>"
1348
+ ]
1349
+ },
1350
+ "metadata": {},
1351
+ "output_type": "display_data"
1352
+ },
1353
+ {
1354
+ "data": {
1355
+ "text/html": [
1356
+ " View run <strong style=\"color:#cdcd00\">cerulean-sweep-4</strong> at: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
1357
+ ],
1358
+ "text/plain": [
1359
+ "<IPython.core.display.HTML object>"
1360
+ ]
1361
+ },
1362
+ "metadata": {},
1363
+ "output_type": "display_data"
1364
+ },
1365
+ {
1366
+ "data": {
1367
+ "text/html": [
1368
+ "Find logs at: <code>./wandb/run-20240404_154533-gtz48xe5/logs</code>"
1369
+ ],
1370
+ "text/plain": [
1371
+ "<IPython.core.display.HTML object>"
1372
+ ]
1373
+ },
1374
+ "metadata": {},
1375
+ "output_type": "display_data"
1376
+ },
1377
+ {
1378
+ "name": "stderr",
1379
+ "output_type": "stream",
1380
+ "text": [
1381
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: aoy7fr9k with config:\n",
1382
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 256\n",
1383
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0.1\n",
1384
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 30\n",
1385
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.0003\n",
1386
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 64\n",
1387
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 8\n",
1388
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 4\n"
1389
+ ]
1390
+ },
1391
+ {
1392
+ "data": {
1393
+ "text/html": [
1394
+ "wandb version 0.16.6 is available! To upgrade, please run:\n",
1395
+ " $ pip install wandb --upgrade"
1396
+ ],
1397
+ "text/plain": [
1398
+ "<IPython.core.display.HTML object>"
1399
+ ]
1400
+ },
1401
+ "metadata": {},
1402
+ "output_type": "display_data"
1403
+ },
1404
+ {
1405
+ "data": {
1406
+ "text/html": [
1407
+ "Tracking run with wandb version 0.16.4"
1408
+ ],
1409
+ "text/plain": [
1410
+ "<IPython.core.display.HTML object>"
1411
+ ]
1412
+ },
1413
+ "metadata": {},
1414
+ "output_type": "display_data"
1415
+ },
1416
+ {
1417
+ "data": {
1418
+ "text/html": [
1419
+ "Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_163029-aoy7fr9k</code>"
1420
+ ],
1421
+ "text/plain": [
1422
+ "<IPython.core.display.HTML object>"
1423
+ ]
1424
+ },
1425
+ "metadata": {},
1426
+ "output_type": "display_data"
1427
+ },
1428
+ {
1429
+ "data": {
1430
+ "text/html": [
1431
+ "Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k' target=\"_blank\">warm-sweep-6</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
1432
+ ],
1433
+ "text/plain": [
1434
+ "<IPython.core.display.HTML object>"
1435
+ ]
1436
+ },
1437
+ "metadata": {},
1438
+ "output_type": "display_data"
1439
+ },
1440
+ {
1441
+ "data": {
1442
+ "text/html": [
1443
+ " View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
1444
+ ],
1445
+ "text/plain": [
1446
+ "<IPython.core.display.HTML object>"
1447
+ ]
1448
+ },
1449
+ "metadata": {},
1450
+ "output_type": "display_data"
1451
+ },
1452
+ {
1453
+ "data": {
1454
+ "text/html": [
1455
+ " View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
1456
+ ],
1457
+ "text/plain": [
1458
+ "<IPython.core.display.HTML object>"
1459
+ ]
1460
+ },
1461
+ "metadata": {},
1462
+ "output_type": "display_data"
1463
+ },
1464
+ {
1465
+ "data": {
1466
+ "text/html": [
1467
+ " View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k</a>"
1468
+ ],
1469
+ "text/plain": [
1470
+ "<IPython.core.display.HTML object>"
1471
+ ]
1472
+ },
1473
+ "metadata": {},
1474
+ "output_type": "display_data"
1475
+ },
1476
+ {
1477
+ "name": "stdout",
1478
+ "output_type": "stream",
1479
+ "text": [
1480
+ "478.595 K model parameters\n"
1481
+ ]
1482
+ }
1483
+ ],
1484
+ "source": [
1485
+ "wandb.agent(sweep_id=sweep_id, function=train)"
1486
+ ]
1487
+ },
1488
+ {
1489
+ "cell_type": "markdown",
1490
+ "metadata": {
1491
+ "id": "cNtTaEc6kxuC"
1492
+ },
1493
+ "source": [
1494
+ "# Test Time\n",
1495
+ "Since this is the best model(validation accuracy) , we will train it on both train and validation data.\n",
1496
+ "We will then test the model on test data"
1497
+ ]
1498
+ },
1499
+ {
1500
+ "cell_type": "markdown",
1501
+ "metadata": {
1502
+ "id": "QcgfjfD9lvWJ"
1503
+ },
1504
+ "source": [
1505
+ "## Best Hyperparameter from validation"
1506
+ ]
1507
+ },
1508
+ {
1509
+ "cell_type": "code",
1510
+ "execution_count": null,
1511
+ "metadata": {
1512
+ "execution": {
1513
+ "iopub.execute_input": "2024-04-06T15:11:46.239015Z",
1514
+ "iopub.status.busy": "2024-04-06T15:11:46.237962Z",
1515
+ "iopub.status.idle": "2024-04-06T15:11:46.337285Z",
1516
+ "shell.execute_reply": "2024-04-06T15:11:46.336384Z",
1517
+ "shell.execute_reply.started": "2024-04-06T15:11:46.238979Z"
1518
+ },
1519
+ "id": "q7SXqJhekxuC",
1520
+ "outputId": "17c0dfd2-2e0b-4449-80fe-9f7a2ce68c28",
1521
+ "trusted": true
1522
+ },
1523
+ "outputs": [
1524
+ {
1525
+ "name": "stdout",
1526
+ "output_type": "stream",
1527
+ "text": [
1528
+ " \n"
1529
+ ]
1530
+ }
1531
+ ],
1532
+ "source": [
1533
+ "n_embd = 128\n",
1534
+ "batch_size = 64\n",
1535
+ "learning_rate = 3e-3\n",
1536
+ "n_head = 8 # other options factors of 32 like 2, 8\n",
1537
+ "n_layers = 6\n",
1538
+ "dropout = 0.1\n",
1539
+ "epochs = 200\n",
1540
+ "\n",
1541
+ "encoder = Encoder(n_embd, n_head, n_layers, dropout)\n",
1542
+ "decoder = Decoder(n_embd, n_head, n_layers, dropout)\n",
1543
+ "encoder.to(device)\n",
1544
+ "decoder.to(device)\n",
1545
+ "print(\" \")"
1546
+ ]
1547
+ },
1548
+ {
1549
+ "cell_type": "markdown",
1550
+ "metadata": {
1551
+ "id": "P0-9k1L6l0iZ"
1552
+ },
1553
+ "source": [
1554
+ "## Train on train_data + val_data"
1555
+ ]
1556
+ },
1557
+ {
1558
+ "cell_type": "code",
1559
+ "execution_count": null,
1560
+ "metadata": {
1561
+ "execution": {
1562
+ "iopub.execute_input": "2024-04-06T15:11:51.054081Z",
1563
+ "iopub.status.busy": "2024-04-06T15:11:51.053142Z",
1564
+ "iopub.status.idle": "2024-04-06T17:55:02.351999Z",
1565
+ "shell.execute_reply": "2024-04-06T17:55:02.350323Z",
1566
+ "shell.execute_reply.started": "2024-04-06T15:11:51.054049Z"
1567
+ },
1568
+ "id": "TQVFJyvlTMjS",
1569
+ "trusted": true
1570
+ },
1571
+ "outputs": [],
1572
+ "source": [
1573
+ "\n",
1574
+ "# print the number of parameters in the model\n",
1575
+ "print(sum([p.numel() for p in encoder.parameters()] + [p.numel() for p in decoder.parameters()])/1e3, 'K model parameters')\n",
1576
+ "\n",
1577
+ "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
1578
+ "val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
1579
+ "\n",
1580
+ "# create a PyTorch optimizer\n",
1581
+ "encoder_optimizer = torch.optim.AdamW(encoder.parameters(), lr=learning_rate)\n",
1582
+ "decoder_optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)\n",
1583
+ "\n",
1584
+ "# print('Step | Training Loss | Validation Loss | Training Accuracy % | Validation Accuracy %')\n",
1585
+ "\n",
1586
+ "least_error = float('inf')\n",
1587
+ "patience = 20 # The number of epochs without improvement to wait before stopping\n",
1588
+ "no_improvement = 0\n",
1589
+ "\n",
1590
+ "for i in range(epochs):\n",
1591
+ " running_loss = 0.0\n",
1592
+ " train_correct = 0\n",
1593
+ "\n",
1594
+ " encoder.train()\n",
1595
+ " decoder.train()\n",
1596
+ "\n",
1597
+ " for j,(train_x,train_y) in enumerate(train_loader):\n",
1598
+ " train_x = train_x.to(device)\n",
1599
+ " train_y = train_y.to(device)\n",
1600
+ "\n",
1601
+ " encoder_optimizer.zero_grad(set_to_none=True)\n",
1602
+ " decoder_optimizer.zero_grad(set_to_none=True)\n",
1603
+ "\n",
1604
+ " encoder_output = encoder(train_x)\n",
1605
+ " logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n",
1606
+ "\n",
1607
+ " encoder_optimizer.zero_grad(set_to_none=True)\n",
1608
+ " decoder_optimizer.zero_grad(set_to_none=True)\n",
1609
+ " loss.backward()\n",
1610
+ " encoder_optimizer.step()\n",
1611
+ " decoder_optimizer.step()\n",
1612
+ "\n",
1613
+ " running_loss += loss\n",
1614
+ " pred_decoder_output = torch.argmax(logits, dim=-1)\n",
1615
+ " # print(pred_decoder_output, \" target: \", train_y[:, 1:])\n",
1616
+ " train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n",
1617
+ "\n",
1618
+ " for j,(train_x,train_y) in enumerate(val_loader):\n",
1619
+ " train_x = train_x.to(device)\n",
1620
+ " train_y = train_y.to(device)\n",
1621
+ "\n",
1622
+ " encoder_optimizer.zero_grad(set_to_none=True)\n",
1623
+ " decoder_optimizer.zero_grad(set_to_none=True)\n",
1624
+ "\n",
1625
+ " encoder_output = encoder(train_x)\n",
1626
+ " logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n",
1627
+ "\n",
1628
+ " encoder_optimizer.zero_grad(set_to_none=True)\n",
1629
+ " decoder_optimizer.zero_grad(set_to_none=True)\n",
1630
+ " loss.backward()\n",
1631
+ " encoder_optimizer.step()\n",
1632
+ " decoder_optimizer.step()\n",
1633
+ "\n",
1634
+ " running_loss += loss\n",
1635
+ " pred_decoder_output = torch.argmax(logits, dim=-1)\n",
1636
+ " # print(pred_decoder_output, \" target: \", train_y[:, 1:])\n",
1637
+ " train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n",
1638
+ "\n",
1639
+ "\n",
1640
+ " metrics = {\n",
1641
+ " \"train_loss\": running_loss.cpu().detach().numpy() / (len(train_data)+len(val_data)),\n",
1642
+ " \"train_acc\": ((train_correct*100) / ((len(train_data)+len(val_data))* (decoder_block_size-1))),\n",
1643
+ " }\n",
1644
+ " if i % 5 == 0:\n",
1645
+ " print(\"Step: \",i)\n",
1646
+ " print(\"train_loss: \", metrics[\"train_loss\"])\n",
1647
+ " print(\"train_acc: \", metrics[\"train_acc\"])"
1648
+ ]
1649
+ },
1650
+ {
1651
+ "cell_type": "code",
1652
+ "execution_count": null,
1653
+ "metadata": {
1654
+ "execution": {
1655
+ "iopub.execute_input": "2024-04-06T00:22:11.853957Z",
1656
+ "iopub.status.busy": "2024-04-06T00:22:11.852912Z",
1657
+ "iopub.status.idle": "2024-04-06T00:22:11.923978Z",
1658
+ "shell.execute_reply": "2024-04-06T00:22:11.923143Z",
1659
+ "shell.execute_reply.started": "2024-04-06T00:22:11.853919Z"
1660
+ },
1661
+ "id": "hAjg5s0IkxuC",
1662
+ "trusted": true
1663
+ },
1664
+ "outputs": [],
1665
+ "source": [
1666
+ "PATH = '/kaggle/working/encoder.pth'\n",
1667
+ "torch.save(encoder, PATH)\n",
1668
+ "PATH = '/kaggle/working/decoder.pth'\n",
1669
+ "torch.save(encoder, PATH)"
1670
+ ]
1671
+ },
1672
+ {
1673
+ "cell_type": "markdown",
1674
+ "metadata": {
1675
+ "id": "x4M3aMxTl-zb"
1676
+ },
1677
+ "source": [
1678
+ "## generate output sequence"
1679
+ ]
1680
+ },
1681
+ {
1682
+ "cell_type": "code",
1683
+ "execution_count": null,
1684
+ "metadata": {
1685
+ "execution": {
1686
+ "iopub.execute_input": "2024-04-06T12:29:19.489092Z",
1687
+ "iopub.status.busy": "2024-04-06T12:29:19.488711Z",
1688
+ "iopub.status.idle": "2024-04-06T12:29:19.496406Z",
1689
+ "shell.execute_reply": "2024-04-06T12:29:19.495353Z",
1690
+ "shell.execute_reply.started": "2024-04-06T12:29:19.489065Z"
1691
+ },
1692
+ "id": "mfIxu6njkxuD",
1693
+ "trusted": true
1694
+ },
1695
+ "outputs": [],
1696
+ "source": [
1697
+ "def generate(input):\n",
1698
+ " B, T = input.shape\n",
1699
+ " encoder_output = encoder(input)\n",
1700
+ " idx = torch.full((B, 1), 2, dtype=torch.long, device=device) # (B,1)\n",
1701
+ "\n",
1702
+ " # idx is (B, T) array of indices in the current context\n",
1703
+ " for _ in range(decoder_block_size-1):\n",
1704
+ " # get the predictions\n",
1705
+ " logits, loss = decoder(idx, encoder_output) # logits (B, T, vocab_size)\n",
1706
+ " # focus only on the last time step\n",
1707
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
1708
+ " # apply softmax to get probabilities\n",
1709
+ " idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (B, 1)\n",
1710
+ " # append sampled index to the running sequence\n",
1711
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
1712
+ " return idx"
1713
+ ]
1714
+ },
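`generate` above decodes greedily, keeping only the most likely token at each step. A stochastic alternative is temperature sampling; a small self-contained sketch with fake logits, for contrast:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 30)    # fake (B, vocab_size) logits for a single step

greedy  = torch.argmax(logits, dim=-1, keepdim=True)  # what generate() does
probs   = F.softmax(logits / 0.8, dim=-1)             # temperature 0.8
sampled = torch.multinomial(probs, num_samples=1)     # stochastic pick
print(greedy.item(), sampled.item())
```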
1715
+ {
1716
+ "cell_type": "markdown",
1717
+ "metadata": {
1718
+ "id": "BeB2nYeFmXy8"
1719
+ },
1720
+ "source": [
1721
+ "## Check Test Accuracy"
1722
+ ]
1723
+ },
1724
+ {
1725
+ "cell_type": "code",
1726
+ "execution_count": null,
1727
+ "metadata": {
1728
+ "execution": {
1729
+ "iopub.execute_input": "2024-04-06T18:00:25.146854Z",
1730
+ "iopub.status.busy": "2024-04-06T18:00:25.146119Z",
1731
+ "iopub.status.idle": "2024-04-06T18:00:25.156303Z",
1732
+ "shell.execute_reply": "2024-04-06T18:00:25.155453Z",
1733
+ "shell.execute_reply.started": "2024-04-06T18:00:25.146826Z"
1734
+ },
1735
+ "id": "dIzXiSLBkxuD",
1736
+ "outputId": "ebe1d201-32bb-4372-e64a-62ebe173799d",
1737
+ "trusted": true
1738
+ },
1739
+ "outputs": [
1740
+ {
1741
+ "name": "stdout",
1742
+ "output_type": "stream",
1743
+ "text": [
1744
+ "test accuracy(word level) : 67.2188\n"
1745
+ ]
1746
+ }
1747
+ ],
1748
+ "source": [
1749
+ "def check():\n",
1750
+ "## validation code\n",
1751
+ " running_loss_val, val_correct = 0, 0\n",
1752
+ " encoder.eval()\n",
1753
+ " decoder.eval()\n",
1754
+ " test_loader = DataLoader(test_data, batch_size=64, shuffle=True)\n",
1755
+ " for _ in range(50):\n",
1756
+ " val_x,val_y = next(iter(test_loader))\n",
1757
+ "\n",
1758
+ " val_x = val_x.to(device)\n",
1759
+ " val_y = val_y.to(device)\n",
1760
+ "\n",
1761
+ " output = generate(val_x)\n",
1762
+ "\n",
1763
+ " encoder_output = encoder(val_x)\n",
1764
+ " logits, loss = decoder(val_y[:, :-1], encoder_output, val_y[:, 1:])\n",
1765
+ "\n",
1766
+ " running_loss_val += loss\n",
1767
+ " # checking val_correct for the whole sequence\n",
1768
+ " val_correct += torch.sum(torch.sum(output[:, 1:] != val_y[:, 1:], dim=-1) == 0)\n",
1769
+ "\n",
1770
+ " print(\"test accuracy(word level) : \", ((val_correct.cpu().detach().numpy()*100) / len(test_data)))\n",
1771
+ "\n",
1772
+ "check()"
1773
+ ]
1774
+ },
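Word-level accuracy is strict: a prediction counts only if every position after the start token matches the target, which is what the `sum(... != ...) == 0` test implements. A toy check (indices made up):

```python
import torch

output = torch.tensor([[2, 5, 6, 7], [2, 5, 9, 7]])  # generated, incl. start token
target = torch.tensor([[2, 5, 6, 7], [2, 5, 6, 7]])

exact = torch.sum(torch.sum(output[:, 1:] != target[:, 1:], dim=-1) == 0)
print(exact.item())  # 1 -> only the first word matches at every position
```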
1775
+ {
1776
+ "cell_type": "markdown",
1777
+ "metadata": {
1778
+ "id": "LDP4KvWdFnIL"
1779
+ },
1780
+ "source": [
1781
+ "# Plotting the Attention HeatMaps"
1782
+ ]
1783
+ },
1784
+ {
1785
+ "cell_type": "code",
1786
+ "execution_count": null,
1787
+ "metadata": {
1788
+ "id": "4WfJEdcgFmiI",
1789
+ "trusted": true
1790
+ },
1791
+ "outputs": [],
1792
+ "source": [
1793
+ "import matplotlib.pyplot as plt\n",
1794
+ "import numpy as np\n",
1795
+ "from matplotlib.font_manager import FontProperties\n",
1796
+ "tel_font = FontProperties(fname = 'TiroDevanagariHindi-Regular.ttf')\n",
1797
+ "# Assuming you have attention_weights of shape (batch_size, output_sequence_length, batch_size, input_sequence_length)\n",
1798
+ "# and prediction_matrix of shape (batch_size, output_sequence_length)\n",
1799
+ "# and input_matrix of shape (batch_size, input_sequence_length)\n",
1800
+ "\n",
1801
+ "# Define the grid dimensions\n",
1802
+ "rows = int(np.ceil(np.sqrt(12)))\n",
1803
+ "cols = int(np.ceil(12 / rows))\n",
1804
+ "\n",
1805
+ "# Create a figure and subplots\n",
1806
+ "fig, axes = plt.subplots(rows, cols, figsize=(9, 9))\n",
1807
+ "\n",
1808
+ "for i, ax in enumerate(axes.flatten()):\n",
1809
+ " if i < 12:\n",
1810
+ " prediction = [opLang.index2char[j.item()] for j in pred[i+1]]\n",
1811
+ "\n",
1812
+ " pred_word=\"\"\n",
1813
+ " input_word=\"\"\n",
1814
+ "\n",
1815
+ " for j in range(len(prediction)):\n",
1816
+ " # Ignore padding\n",
1817
+ " if(prediction[j] != '#'):\n",
1818
+ " pred_word += prediction[j]\n",
1819
+ " else :\n",
1820
+ " break\n",
1821
+ " input_seq = [ipLang.index2char[j.item()] for j in testData[i][0]]\n",
1822
+ "\n",
1823
+ " for j in range(len(input_seq)):\n",
1824
+ " if(input_seq[j] != '#'):\n",
1825
+ " input_word += input_seq[j]\n",
1826
+ " else :\n",
1827
+ " break\n",
1828
+ " attn_weights = atten_weights[i, :len(pred_word), :len(input_word)].detach().cpu().numpy()\n",
1829
+ " ax.imshow(attn_weights.T, cmap='hot', interpolation='nearest')\n",
1830
+ " ax.xaxis.set_label_position('top')\n",
1831
+ " ax.set_title(f'Example {i+1}')\n",
1832
+ " ax.set_xlabel('Output predicted')\n",
1833
+ " ax.set_ylabel('Input word')\n",
1834
+ " ax.set_xticks(np.arange(len(pred_word)))\n",
1835
+ " ax.set_xticklabels(pred_word, rotation = 90, fontproperties = tel_font,fontdict={'fontsize':8})\n",
1836
+ " ax.xaxis.tick_top()\n",
1837
+ "\n",
1838
+ " ax.set_yticks(np.arange(len(input_word)))\n",
1839
+ " ax.set_yticklabels(input_word, rotation=90)\n",
1840
+ "\n",
1841
+ "\n",
1842
+ "\n",
1843
+ "# Adjust the spacing between subplots\n",
1844
+ "plt.tight_layout()\n",
1845
+ "\n",
1846
+ "# Show the plot\n",
1847
+ "plt.show()\n",
1848
+ "wandb.init(project='CS6910_Assignment_3')\n",
1849
+ "\n",
1850
+ "# Convert the matplotlib figure to an image\n",
1851
+ "fig.canvas.draw()\n",
1852
+ "image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')\n",
1853
+ "image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
1854
+ "\n",
1855
+ "# Log the image in wandb\n",
1856
+ "wandb.log({\"attention_heatmaps\": [wandb.Image(image)]})"
1857
+ ]
1858
+ },
1859
+ {
1860
+ "cell_type": "code",
1861
+ "execution_count": null,
1862
+ "metadata": {
1863
+ "id": "FnHR_oql6-S4"
1864
+ },
1865
+ "outputs": [],
1866
+ "source": []
1867
+ }
1868
+ ],
1869
+ "metadata": {
1870
+ "accelerator": "GPU",
1871
+ "colab": {
1872
+ "collapsed_sections": [
1873
+ "hRdpoWePeYHn",
1874
+ "44xIRolL_T_d",
1875
+ "XdltQ7oJCq1j",
1876
+ "GgPU486JC8Mz",
1877
+ "658W9RARGEUf",
1878
+ "q7fAgs5uQni_",
1879
+ "n4rGh7vuQqaa",
1880
+ "nvyRJWUUbR2f",
1881
+ "8ETW0BG_Pa24",
1882
+ "MQPGy32rnD3V",
1883
+ "z_aYZvDD1OHU",
1884
+ "pKvBd5mKf0Hf",
1885
+ "FYMa5jTQRUaB",
1886
+ "zfuv5FoA1wt2",
1887
+ "W7CYNChRGuGK"
1888
+ ],
1889
+ "gpuType": "T4",
1890
+ "include_colab_link": true,
1891
+ "provenance": [],
1892
+ "toc_visible": true
1893
+ },
1894
+ "kaggle": {
1895
+ "accelerator": "gpu",
1896
+ "dataSources": [
1897
+ {
1898
+ "datasetId": 4721249,
1899
+ "sourceId": 8013732,
1900
+ "sourceType": "datasetVersion"
1901
+ }
1902
+ ],
1903
+ "dockerImageVersionId": 30674,
1904
+ "isGpuEnabled": true,
1905
+ "isInternetEnabled": true,
1906
+ "language": "python",
1907
+ "sourceType": "notebook"
1908
+ },
1909
+ "kernelspec": {
1910
+ "display_name": "Python 3",
1911
+ "language": "python",
1912
+ "name": "python3"
1913
+ },
1914
+ "language_info": {
1915
+ "codemirror_mode": {
1916
+ "name": "ipython",
1917
+ "version": 3
1918
+ },
1919
+ "file_extension": ".py",
1920
+ "mimetype": "text/x-python",
1921
+ "name": "python",
1922
+ "nbconvert_exporter": "python",
1923
+ "pygments_lexer": "ipython3",
1924
+ "version": "3.10.13"
1925
+ }
1926
+ },
1927
+ "nbformat": 4,
1928
+ "nbformat_minor": 0
1929
+ }
predictions_attention/predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
predictions_transformer/predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
predictions_vanilla/predictions _vanilla.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,93 @@
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.6.0
4
+ asttokens==2.4.1
5
+ certifi==2024.8.30
6
+ charset-normalizer==3.3.2
7
+ click==8.1.7
8
+ colorama==0.4.6
9
+ comm==0.2.2
10
+ contourpy==1.3.0
11
+ cycler==0.12.1
12
+ debugpy==1.8.6
13
+ decorator==5.1.1
14
+ docker-pycreds==0.4.0
15
+ executing==2.1.0
16
+ fastapi==0.115.0
17
+ ffmpy==0.4.0
18
+ filelock==3.16.1
19
+ fonttools==4.54.1
20
+ fsspec==2024.9.0
21
+ gitdb==4.0.11
22
+ GitPython==3.1.43
23
+ gradio==4.44.1
24
+ gradio_client==1.3.0
25
+ h11==0.14.0
26
+ httpcore==1.0.6
27
+ httpx==0.27.2
28
+ huggingface-hub==0.25.1
29
+ idna==3.10
30
+ importlib_resources==6.4.5
31
+ ipykernel==6.29.5
32
+ ipython==8.28.0
33
+ jedi==0.19.1
34
+ Jinja2==3.1.4
35
+ jupyter_client==8.6.3
36
+ jupyter_core==5.7.2
37
+ kiwisolver==1.4.7
38
+ markdown-it-py==3.0.0
39
+ MarkupSafe==2.1.5
40
+ matplotlib==3.9.2
41
+ matplotlib-inline==0.1.7
42
+ mdurl==0.1.2
43
+ mpmath==1.3.0
44
+ nest-asyncio==1.6.0
45
+ networkx==3.3
46
+ numpy==2.1.1
47
+ orjson==3.10.7
48
+ packaging==24.1
49
+ pandas==2.2.3
50
+ parso==0.8.4
51
+ pillow==10.4.0
52
+ platformdirs==4.3.6
53
+ prompt_toolkit==3.0.48
54
+ protobuf==5.28.2
55
+ psutil==6.0.0
56
+ pure_eval==0.2.3
57
+ pydantic==2.9.2
58
+ pydantic_core==2.23.4
59
+ pydub==0.25.1
60
+ Pygments==2.18.0
61
+ pyparsing==3.1.4
62
+ python-dateutil==2.9.0.post0
63
+ python-multipart==0.0.12
64
+ pytz==2024.2
65
+ PyYAML==6.0.2
66
+ pyzmq==26.2.0
67
+ requests==2.32.3
68
+ rich==13.9.1
69
+ ruff==0.6.8
70
+ semantic-version==2.10.0
71
+ sentry-sdk==2.15.0
72
+ setproctitle==1.3.3
73
+ setuptools==75.1.0
74
+ shellingham==1.5.4
75
+ six==1.16.0
76
+ smmap==5.0.1
77
+ sniffio==1.3.1
78
+ stack-data==0.6.3
79
+ starlette==0.38.6
80
+ sympy==1.13.3
81
+ tomlkit==0.12.0
82
+ torch==2.4.1
83
+ tornado==6.4.1
84
+ tqdm==4.66.5
85
+ traitlets==5.14.3
86
+ typer==0.12.5
87
+ typing_extensions==4.12.2
88
+ tzdata==2024.2
89
+ urllib3==2.2.3
90
+ uvicorn==0.31.0
91
+ wandb==0.18.3
92
+ wcwidth==0.2.13
93
+ websockets==12.0
src/decoder.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from src.helper import get_cell
5
+
6
+ class Decoder(nn.Module):
7
+ def __init__(self,
8
+ out_sz: int,
9
+ embed_sz: int,
10
+ hidden_sz: int,
11
+ cell_type: str,
12
+ n_layers: int,
13
+ dropout: float,
14
+ device: str):
15
+
16
+ super(Decoder, self).__init__()
17
+ self.hidden_sz = hidden_sz
18
+ self.n_layers = n_layers
19
+ self.dropout = dropout
20
+ self.cell_type = cell_type
21
+ self.embedding = nn.Embedding(out_sz, embed_sz)
22
+ self.device = device
23
+
24
+ self.rnn = get_cell(cell_type)(input_size = embed_sz,
25
+ hidden_size = hidden_sz,
26
+ num_layers = n_layers,
27
+ dropout = dropout)
28
+
29
+ self.out = nn.Linear(hidden_sz, out_sz)
30
+ self.softmax = nn.LogSoftmax(dim=1)
31
+
32
+ def forward(self, input, hidden, cell):
33
+ output = self.embedding(input).view(1, 1, -1)
34
+ output = F.relu(output)
35
+
36
+ if(self.cell_type == "LSTM"):
37
+ output, (hidden, cell) = self.rnn(output, (hidden, cell))
38
+ else:
39
+ output, hidden = self.rnn(output, hidden)
40
+
41
+ output = self.softmax(self.out(output[0]))
42
+ return output, hidden, cell
43
+
44
+ def initHidden(self):
45
+ return torch.zeros(self.n_layers, 1, self.hidden_sz, device=self.device)
46
+
src/encoder.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.helper import get_cell
4
+
5
+ class Encoder(nn.Module):
6
+ def __init__(self,
7
+ in_sz: int,
8
+ embed_sz: int,
9
+ hidden_sz: int,
10
+ cell_type: str,
11
+ n_layers: int,
12
+ dropout: float,
13
+ device: str):
14
+
15
+ super(Encoder, self).__init__()
16
+ self.hidden_sz = hidden_sz
17
+ self.n_layers = n_layers
18
+ self.dropout = dropout
19
+ self.cell_type = cell_type
20
+ self.embedding = nn.Embedding(in_sz, embed_sz)
21
+ self.device = device
22
+
23
+ self.rnn = get_cell(cell_type)(input_size = embed_sz,
24
+ hidden_size = hidden_sz,
25
+ num_layers = n_layers,
26
+ dropout = dropout)
27
+
28
+ def forward(self, input, hidden, cell):
29
+ embedded = self.embedding(input).view(1, 1, -1)
30
+
31
+ if(self.cell_type == "LSTM"):
32
+ output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
33
+ else:
34
+ output, hidden = self.rnn(embedded, hidden)
35
+
36
+ return output, hidden, cell
37
+
38
+ def initHidden(self):
39
+ return torch.zeros(self.n_layers, 1, self.hidden_sz, device=self.device)
src/helper.py ADDED
@@ -0,0 +1,60 @@
1
+ import pandas as pd
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from src.language import Language, EOS_token
6
+
7
+ def get_data(lang: str, split: str) -> list[list[str]]:
8
+ """
9
+ Returns: 'pairs': list of [input_word, target_word] pairs
10
+ """
11
+     path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, split)
12
+ df = pd.read_csv(path, header=None)
13
+ pairs = df.values.tolist()
14
+ return pairs
15
+
16
+ def get_languages(lang: str):
17
+ """
18
+ Returns
19
+ 1. input_lang: input language - English
20
+ 2. output_lang: output language - Given language
21
+ 3. pairs: list of [input_word, target_word] pairs
22
+ """
23
+ input_lang = Language('eng')
24
+ output_lang = Language(lang)
25
+ pairs = get_data(lang, "train")
26
+ for pair in pairs:
27
+ input_lang.addWord(pair[0])
28
+ output_lang.addWord(pair[1])
29
+ return input_lang, output_lang, pairs
30
+
31
+ def get_cell(cell_type: str):
32
+ if cell_type == "LSTM":
33
+ return nn.LSTM
34
+ elif cell_type == "GRU":
35
+ return nn.GRU
36
+ elif cell_type == "RNN":
37
+ return nn.RNN
38
+ else:
39
+ raise Exception("Invalid cell type")
40
+
41
+ def get_optimizer(optimizer: str):
42
+ if optimizer == "SGD":
43
+ return optim.SGD
44
+ elif optimizer == "ADAM":
45
+ return optim.Adam
46
+ else:
47
+ raise Exception("Invalid optimizer")
48
+
49
+ def indexesFromWord(lang:Language, word:str):
50
+ return [lang.word2index[char] for char in word]
51
+
52
+ def tensorFromWord(lang:Language, word:str, device:str):
53
+ indexes = indexesFromWord(lang, word)
54
+ indexes.append(EOS_token)
55
+ return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
56
+
57
+ def tensorsFromPair(input_lang:Language, output_lang:Language, pair:list[str], device:str):
58
+ input_tensor = tensorFromWord(input_lang, pair[0], device)
59
+ target_tensor = tensorFromWord(output_lang, pair[1], device)
60
+ return (input_tensor, target_tensor)
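Taken together, these helpers turn a CSV of word pairs into index tensors ready for the encoder/decoder loop. A quick sketch (assumes the `aksharantar_sampled/` directory referenced in `get_data` is present):

```python
from src.helper import get_languages, tensorsFromPair

input_lang, output_lang, pairs = get_languages("tam")
inp, tgt = tensorsFromPair(input_lang, output_lang, pairs[0], "cpu")
print(inp.shape, tgt.shape)   # (word_len + 1, 1) each; the extra row is the EOS token
```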
src/language.py ADDED
@@ -0,0 +1,24 @@
+ # Language Model
+ SOS_token = 0
+ EOS_token = 1
+
+ class Language:
+     def __init__(self, name):
+         self.name = name
+         self.word2index = {}
+         self.word2count = {}
+         self.index2word = {SOS_token: "<", EOS_token: ">"}
+         self.n_chars = 2  # Count SOS and EOS
+
+     def addWord(self, word):
+         for char in word:
+             self.addChar(char)
+
+     def addChar(self, char):
+         if char not in self.word2index:
+             self.word2index[char] = self.n_chars
+             self.word2count[char] = 1
+             self.index2word[self.n_chars] = char
+             self.n_chars += 1
+         else:
+             self.word2count[char] += 1
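SOS and EOS map to the printable characters `<` and `>` so decoded strings stay readable, and every new character gets the next free index. A small sketch of how the vocabulary grows:

```python
from src.language import Language

lang = Language("eng")
lang.addWord("hello")
print(lang.n_chars)          # 6: SOS, EOS plus the unique characters h, e, l, o
print(lang.word2index["l"])  # 4 -- indices are assigned in order of first appearance
```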
src/translator.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from src.helper import get_optimizer, tensorsFromPair, get_languages, tensorFromWord, get_data
+ from src.language import SOS_token, EOS_token
+ from src.encoder import Encoder
+ from src.decoder import Decoder
+ import random
+ import time
+ import numpy as np
+
+ PRINT_EVERY = 5000
+ PLOT_EVERY = 100
+
+ class Translator:
+     def __init__(self, lang: str, params: dict, device: str):
+         self.lang = lang
+         self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
+         self.input_size = self.input_lang.n_chars
+         self.output_size = self.output_lang.n_chars
+         self.device = device
+
+         self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair, self.device) for pair in self.pairs]
+
+         self.encoder = Encoder(in_sz = self.input_size,
+                                embed_sz = params["embed_size"],
+                                hidden_sz = params["hidden_size"],
+                                cell_type = params["cell_type"],
+                                n_layers = params["num_layers"],
+                                dropout = params["dropout"],
+                                device = self.device).to(self.device)
+
+         self.decoder = Decoder(out_sz = self.output_size,
+                                embed_sz = params["embed_size"],
+                                hidden_sz = params["hidden_size"],
+                                cell_type = params["cell_type"],
+                                n_layers = params["num_layers"],
+                                dropout = params["dropout"],
+                                device = self.device).to(self.device)
+
+         self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"])
+         self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"])
+
+         self.criterion = nn.NLLLoss()
+
+         self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
+         self.max_length = params["max_length"]
+
+     def train_single(self, input_tensor, target_tensor):
+         encoder_hidden = self.encoder.initHidden()
+         encoder_cell = self.encoder.initHidden()
+
+         self.encoder_optimizer.zero_grad()
+         self.decoder_optimizer.zero_grad()
+
+         input_length = input_tensor.size(0)
+         target_length = target_tensor.size(0)
+
+         encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=self.device)
+
+         loss = 0
+
+         for ei in range(input_length):
+             encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+             encoder_outputs[ei] = encoder_output[0, 0]
+
+         decoder_input = torch.tensor([[SOS_token]], device=self.device)
+         decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+         use_teacher_forcing = random.random() < self.teacher_forcing_ratio
+
+         if use_teacher_forcing:
+             for di in range(target_length):
+                 decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)
+                 loss += self.criterion(decoder_output, target_tensor[di])
+
+                 decoder_input = target_tensor[di]
+         else:
+             for di in range(target_length):
+                 decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)
+                 loss += self.criterion(decoder_output, target_tensor[di])
+
+                 topv, topi = decoder_output.topk(1)
+                 decoder_input = topi.squeeze().detach()
+                 if decoder_input.item() == EOS_token:
+                     break
+
+         loss.backward()
+         self.encoder_optimizer.step()
+         self.decoder_optimizer.step()
+
+         return loss.item() / target_length
+
+     def train(self, iters=-1):
+         start_time = time.time()
+         plot_losses = []
+         print_loss_total = 0
+         plot_loss_total = 0
+
+         random.shuffle(self.training_pairs)
+         iters = len(self.training_pairs) if iters == -1 else iters
+
+         for iter in range(1, iters + 1):
+             training_pair = self.training_pairs[iter - 1]
+             input_tensor = training_pair[0]
+             target_tensor = training_pair[1]
+
+             loss = self.train_single(input_tensor, target_tensor)
+             print_loss_total += loss
+             plot_loss_total += loss
+
+             if iter % PRINT_EVERY == 0:
+                 print_loss_avg = print_loss_total / PRINT_EVERY
+                 print_loss_total = 0
+                 current_time = time.time()
+                 print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))
+
+             if iter % PLOT_EVERY == 0:
+                 plot_loss_avg = plot_loss_total / PLOT_EVERY
+                 plot_losses.append(plot_loss_avg)
+                 plot_loss_total = 0
+
+         return plot_losses
+
+     def evaluate(self, word):
+         with torch.no_grad():
+             input_tensor = tensorFromWord(self.input_lang, word, self.device)
+             input_length = input_tensor.size()[0]
+             encoder_hidden = self.encoder.initHidden()
+             encoder_cell = self.encoder.initHidden()
+
+             encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=self.device)
+
+             for ei in range(input_length):
+                 encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+                 encoder_outputs[ei] += encoder_output[0, 0]
+
+             decoder_input = torch.tensor([[SOS_token]], device=self.device)
+             decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+             decoded_chars = ""
+
+             for di in range(self.max_length):
+                 decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)
+                 topv, topi = decoder_output.topk(1)
+
+                 if topi.item() == EOS_token:
+                     break
+                 else:
+                     decoded_chars += self.output_lang.index2word[topi.item()]
+
+                 decoder_input = topi.squeeze().detach()
+
+             return decoded_chars
+
+     def test_validate(self, type: str):
+         pairs = get_data(self.lang, type)
+         accuracy = np.sum([self.evaluate(pair[0]) == pair[1] for pair in pairs])
+         return accuracy / len(pairs)
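`Translator` wires everything together: it builds the vocabularies, converts every training pair to tensors up front, and trains the encoder and decoder with separate optimizers. A sketch of a full training pass (hyperparameters mirror the defaults in `train.py` below):

```python
import torch
from src.translator import Translator

device = "cuda" if torch.cuda.is_available() else "cpu"
params = {"embed_size": 16, "hidden_size": 512, "cell_type": "LSTM",
          "num_layers": 2, "dropout": 0.1, "learning_rate": 0.005,
          "optimizer": "SGD", "teacher_forcing_ratio": 0.5, "max_length": 50}

model = Translator("tam", params, device)
losses = model.train()                  # one pass over the shuffled training pairs
print(model.test_validate("valid"))     # exact-match word accuracy on the validation split
```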
test_best_attention.py ADDED
@@ -0,0 +1,380 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import random
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import time
+
+ import wandb
+ wandb.login()
+
+ random.seed()
+
+ import os
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(device)
+
+ # Language Model
+ SOS_token = 0
+ EOS_token = 1
+
+ class Language:
+     def __init__(self, name):
+         self.name = name
+         self.word2index = {}
+         self.word2count = {}
+         self.index2word = {SOS_token: "<", EOS_token: ">"}
+         self.n_chars = 2  # Count SOS and EOS
+
+     def addWord(self, word):
+         for char in word:
+             self.addChar(char)
+
+     def addChar(self, char):
+         if char not in self.word2index:
+             self.word2index[char] = self.n_chars
+             self.word2count[char] = 1
+             self.index2word[self.n_chars] = char
+             self.n_chars += 1
+         else:
+             self.word2count[char] += 1
+
+ def get_data(lang: str, type: str) -> list[list[str]]:
+     """
+     Returns: 'pairs': list of [input_word, target_word] pairs
+     """
+     path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
+     df = pd.read_csv(path, header=None)
+     pairs = df.values.tolist()
+     return pairs
+
+ def get_languages(lang: str):
+     """
+     Returns
+     1. input_lang: input language - English
+     2. output_lang: output language - Given language
+     3. pairs: list of [input_word, target_word] pairs
+     """
+     input_lang = Language('eng')
+     output_lang = Language(lang)
+     pairs = get_data(lang, "train")
+     for pair in pairs:
+         input_lang.addWord(pair[0])
+         output_lang.addWord(pair[1])
+     return input_lang, output_lang, pairs
+
+ def get_cell(cell_type: str):
+     if cell_type == "LSTM":
+         return nn.LSTM
+     elif cell_type == "GRU":
+         return nn.GRU
+     elif cell_type == "RNN":
+         return nn.RNN
+     else:
+         raise Exception("Invalid cell type")
+
+ def get_optimizer(optimizer: str):
+     if optimizer == "SGD":
+         return optim.SGD
+     elif optimizer == "ADAM":
+         return optim.Adam
+     else:
+         raise Exception("Invalid optimizer")
+
+ class Encoder(nn.Module):
+     def __init__(self,
+                  in_sz: int,
+                  embed_sz: int,
+                  hidden_sz: int,
+                  cell_type: str,
+                  n_layers: int,
+                  dropout: float):
+
+         super(Encoder, self).__init__()
+         self.hidden_sz = hidden_sz
+         self.n_layers = n_layers
+         self.dropout = dropout
+         self.cell_type = cell_type
+         self.embedding = nn.Embedding(in_sz, embed_sz)
+
+         self.rnn = get_cell(cell_type)(input_size = embed_sz,
+                                        hidden_size = hidden_sz,
+                                        num_layers = n_layers,
+                                        dropout = dropout)
+
+     def forward(self, input, hidden, cell):
+         embedded = self.embedding(input).view(1, 1, -1)
+
+         if self.cell_type == "LSTM":
+             output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
+         else:
+             output, hidden = self.rnn(embedded, hidden)
+
+         return output, hidden, cell
+
+     def initHidden(self):
+         return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
+
+ class AttentionDecoder(nn.Module):
+     def __init__(self,
+                  out_sz: int,
+                  embed_sz: int,
+                  hidden_sz: int,
+                  cell_type: str,
+                  n_layers: int,
+                  dropout: float):
+
+         super(AttentionDecoder, self).__init__()
+         self.hidden_sz = hidden_sz
+         self.n_layers = n_layers
+         self.dropout = dropout
+         self.cell_type = cell_type
+         self.embedding = nn.Embedding(out_sz, embed_sz)
+
+         self.attn = nn.Linear(hidden_sz + embed_sz, 50)  # 50 = max_length: one attention score per encoder position
+         self.attn_combine = nn.Linear(hidden_sz + embed_sz, hidden_sz)
+
+         self.rnn = get_cell(cell_type)(input_size = hidden_sz,
+                                        hidden_size = hidden_sz,
+                                        num_layers = n_layers,
+                                        dropout = dropout)
+
+         self.out = nn.Linear(hidden_sz, out_sz)
+         self.softmax = nn.LogSoftmax(dim=1)
+
+     def forward(self, input, hidden, cell, encoder_outputs):
+         embedding = self.embedding(input).view(1, 1, -1)
+
+         attn_weights = F.softmax(self.attn(torch.cat((embedding[0], hidden[0]), 1)), dim=1)
+         attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
+
+         output = torch.cat((embedding[0], attn_applied[0]), 1)
+         output = self.attn_combine(output).unsqueeze(0)
+
+         if self.cell_type == "LSTM":
+             output, (hidden, cell) = self.rnn(output, (hidden, cell))
+         else:
+             output, hidden = self.rnn(output, hidden)
+
+         output = self.softmax(self.out(output[0]))
+         return output, hidden, cell, attn_weights
+
+     def initHidden(self):
+         return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
+
+ def indexesFromWord(lang: Language, word: str):
+     return [lang.word2index[char] for char in word]
+
+ def tensorFromWord(lang: Language, word: str):
+     indexes = indexesFromWord(lang, word)
+     indexes.append(EOS_token)
+     return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
+
+ def tensorsFromPair(input_lang: Language, output_lang: Language, pair: list[str]):
+     input_tensor = tensorFromWord(input_lang, pair[0])
+     target_tensor = tensorFromWord(output_lang, pair[1])
+     return (input_tensor, target_tensor)
+
+ def params_definition():
+     """
+     params:
+
+     embed_size : size of embedding (input and output) (8, 16, 32, 64)
+     hidden_size : size of hidden layer (64, 128, 256, 512)
+     cell_type : type of cell (LSTM, GRU, RNN)
+     num_layers : number of layers in encoder (1, 2, 3)
+     dropout : dropout probability
+     learning_rate : learning rate
+     teacher_forcing_ratio : teacher forcing ratio (0.5 fixed for now)
+     optimizer : optimizer (SGD, Adam)
+     max_length : maximum length of input word (50 fixed for now)
+
+     """
+     pass
+
+ PRINT_EVERY = 5000
+ PLOT_EVERY = 100
+
+ class Translator:
+     def __init__(self, lang: str, params: dict):
+         self.lang = lang
+         self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
+         self.input_size = self.input_lang.n_chars
+         self.output_size = self.output_lang.n_chars
+
+         self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair) for pair in self.pairs]
+
+         self.encoder = Encoder(in_sz = self.input_size,
+                                embed_sz = params["embed_size"],
+                                hidden_sz = params["hidden_size"],
+                                cell_type = params["cell_type"],
+                                n_layers = params["num_layers"],
+                                dropout = params["dropout"]).to(device)
+
+         self.decoder = AttentionDecoder(out_sz = self.output_size,
+                                         embed_sz = params["embed_size"],
+                                         hidden_sz = params["hidden_size"],
+                                         cell_type = params["cell_type"],
+                                         n_layers = params["num_layers"],
+                                         dropout = params["dropout"]).to(device)
+
+         self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
+         self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
+
+         self.criterion = nn.NLLLoss()
+
+         self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
+         self.max_length = params["max_length"]
+
+     def train_single(self, input_tensor, target_tensor):
+         encoder_hidden = self.encoder.initHidden()
+         encoder_cell = self.encoder.initHidden()
+
+         self.encoder_optimizer.zero_grad()
+         self.decoder_optimizer.zero_grad()
+
+         input_length = input_tensor.size(0)
+         target_length = target_tensor.size(0)
+
+         encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
+
+         loss = 0
+
+         for ei in range(input_length):
+             encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+             encoder_outputs[ei] = encoder_output[0, 0]
+
+         decoder_input = torch.tensor([[SOS_token]], device=device)
+         decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+         use_teacher_forcing = random.random() < self.teacher_forcing_ratio
+
+         if use_teacher_forcing:
+             for di in range(target_length):
+                 decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                 loss += self.criterion(decoder_output, target_tensor[di])
+
+                 decoder_input = target_tensor[di]
+         else:
+             for di in range(target_length):
+                 decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                 loss += self.criterion(decoder_output, target_tensor[di])
+
+                 topv, topi = decoder_output.topk(1)
+                 decoder_input = topi.squeeze().detach()
+                 if decoder_input.item() == EOS_token:
+                     break
+
+         loss.backward()
+         self.encoder_optimizer.step()
+         self.decoder_optimizer.step()
+
+         return loss.item() / target_length
+
+     def train(self, iters=-1):
+         start_time = time.time()
+         plot_losses = []
+         print_loss_total = 0
+         plot_loss_total = 0
+
+         random.shuffle(self.training_pairs)
+         iters = len(self.training_pairs) if iters == -1 else iters
+
+         for iter in range(1, iters + 1):
+             training_pair = self.training_pairs[iter - 1]
+             input_tensor = training_pair[0]
+             target_tensor = training_pair[1]
+
+             loss = self.train_single(input_tensor, target_tensor)
+             print_loss_total += loss
+             plot_loss_total += loss
+
+             if iter % PRINT_EVERY == 0:
+                 print_loss_avg = print_loss_total / PRINT_EVERY
+                 print_loss_total = 0
+                 current_time = time.time()
+                 print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))
+
+             if iter % PLOT_EVERY == 0:
+                 plot_loss_avg = plot_loss_total / PLOT_EVERY
+                 plot_losses.append(plot_loss_avg)
+                 plot_loss_total = 0
+
+         return plot_losses
+
+     def evaluate(self, word):
+         with torch.no_grad():
+             input_tensor = tensorFromWord(self.input_lang, word)
+             input_length = input_tensor.size()[0]
+             encoder_hidden = self.encoder.initHidden()
+             encoder_cell = self.encoder.initHidden()
+
+             encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
+
+             for ei in range(input_length):
+                 encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+                 encoder_outputs[ei] += encoder_output[0, 0]
+
+             decoder_input = torch.tensor([[SOS_token]], device=device)
+             decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+             decoded_chars = ""
+             decoder_attentions = torch.zeros(self.max_length, self.max_length)
+
+             for di in range(self.max_length):
+                 decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                 decoder_attentions[di] = decoder_attention.data
+                 topv, topi = decoder_output.topk(1)
+
+                 if topi.item() == EOS_token:
+                     break
+                 else:
+                     decoded_chars += self.output_lang.index2word[topi.item()]
+
+                 decoder_input = topi.squeeze().detach()
+
+             return decoded_chars, decoder_attentions[:di + 1]
+
+     def test_validate(self, type: str):
+         pairs = get_data(self.lang, type)
+         accuracy = 0
+         for pair in pairs:
+             output, _ = self.evaluate(pair[0])
+             if output == pair[1]:
+                 accuracy += 1
+         return accuracy / len(pairs)
+
+ params = {
+     "embed_size": 32,
+     "hidden_size": 256,
+     "cell_type": "RNN",
+     "num_layers": 2,
+     "dropout": 0.1,
+     "learning_rate": 0.001,
+     "optimizer": "SGD",
+     "teacher_forcing_ratio": 0.5,
+     "max_length": 50,
+     "weight_decay": 0.001
+ }
+
+ model = Translator('tam', params)
+
+ model.encoder.load_state_dict(torch.load('./best_models_attn/encoder.pt'))
+ model.decoder.load_state_dict(torch.load('./best_models_attn/decoder.pt'))
+
+ with open("test_gen_attn.txt", "w") as f:
+     test_data = get_data("tam", "test")
+     f.write("Input, Target, Output\n")
+     accuracy = 0
+     for i in range(len(test_data)):
+         word, _ = model.evaluate(test_data[i][0])
+         f.write(test_data[i][0] + ", " + test_data[i][1] + ", " + word + "\n")
+         if test_data[i][1] == word:
+             accuracy += 1
+
+ print("Test Accuracy: " + str(accuracy/len(test_data) * 100) + "%")
test_best_vanilla.py ADDED
@@ -0,0 +1,36 @@
+ from src.translator import Translator
+ import torch
+ import random
+ from src.helper import get_data
+
+ random.seed()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ params = {
+     "embed_size": 16,
+     "hidden_size": 512,
+     "cell_type": "LSTM",
+     "num_layers": 2,
+     "dropout": 0.1,
+     "learning_rate": 0.005,
+     "optimizer": "SGD",
+     "teacher_forcing_ratio": 0.5,
+     "max_length": 50
+ }
+
+ model = Translator("tam", params, device)
+
+ model.encoder.load_state_dict(torch.load("./best_model_vanilla/encoder.pt"))
+ model.decoder.load_state_dict(torch.load("./best_model_vanilla/decoder.pt"))
+
+ with open("test_gen.txt", "w") as f:
+     test_data = get_data("tam", "test")
+     f.write("Input, Target, Output\n")
+     accuracy = 0
+     for i in range(len(test_data)):
+         output = model.evaluate(test_data[i][0])  # evaluate once and reuse the prediction
+         f.write(test_data[i][0] + ", " + test_data[i][1] + ", " + output + "\n")
+         if test_data[i][1] == output:
+             accuracy += 1
+
+ print("Test Accuracy: " + str(accuracy/len(test_data) * 100) + "%")
train.py ADDED
@@ -0,0 +1,91 @@
+ from src.translator import Translator
+ import torch
+ import random
+ import argparse
+
+ random.seed()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ params = {
+     "embed_size": 16,
+     "hidden_size": 512,
+     "cell_type": "LSTM",
+     "num_layers": 2,
+     "dropout": 0.1,
+     "learning_rate": 0.005,
+     "optimizer": "SGD",
+     "teacher_forcing_ratio": 0.5,
+     "max_length": 50
+ }
+
+ language = "tam"
+
+ # Argument Parser
+ parser = argparse.ArgumentParser(description="Transliteration Model")
+ parser.add_argument("-es", "--embed_size", type=int, default=16, help="Embedding Size, good_choices = [8, 16, 32]")
+ parser.add_argument("-hs", "--hidden_size", type=int, default=512, help="Hidden Size, good_choices = [128, 256, 512]")
+ parser.add_argument("-ct", "--cell_type", type=str, default="LSTM", help="Cell Type, choices: [LSTM, GRU, RNN]")
+ parser.add_argument("-nl", "--num_layers", type=int, default=2, help="Number of Layers, choices: [1, 2, 3]")
+ parser.add_argument("-d", "--dropout", type=float, default=0.1, help="Dropout, good_choices: [0, 0.1, 0.2]")
+ parser.add_argument("-lr", "--learning_rate", type=float, default=0.005, help="Learning Rate, good_choices: [0.0005, 0.001, 0.005]")
+ parser.add_argument("-o", "--optimizer", type=str, default="SGD", help="Optimizer, choices: [SGD, ADAM]")
+ parser.add_argument("-l", "--language", type=str, default="tam", help="Language")
+ args = parser.parse_args()
+
+ params["embed_size"] = args.embed_size
+ params["hidden_size"] = args.hidden_size
+ params["cell_type"] = args.cell_type
+ params["num_layers"] = args.num_layers
+ params["dropout"] = args.dropout
+ params["learning_rate"] = args.learning_rate
+ params["optimizer"] = args.optimizer
+ language = args.language
+
+ model = Translator(language, params, device)
+
+ print("Training Model")
+ print("Language: {}".format(language))
+ print("Embedding Size: {}".format(params["embed_size"]))
+ print("Hidden Size: {}".format(params["hidden_size"]))
+ print("Cell Type: {}".format(params["cell_type"]))
+ print("Number of Layers: {}".format(params["num_layers"]))
+ print("Dropout: {}".format(params["dropout"]))
+ print("Learning Rate: {}".format(params["learning_rate"]))
+ print("Optimizer: {}".format(params["optimizer"]))
+ print("Teacher Forcing Ratio: {}".format(params["teacher_forcing_ratio"]))
+ print("Max Length: {}\n".format(params["max_length"]))
+
+ epochs = 10
+ old_validation_accuracy = 0
+
+ for epoch in range(epochs):
+     print("Epoch: {}".format(epoch + 1))
+     plot_losses = model.train()
+
+     # take average of plot losses as training loss
+     training_loss = sum(plot_losses) / len(plot_losses)
+
+     print("Training Loss: {:.4f}".format(training_loss))
+
+     training_accuracy = model.test_validate('train')
+     print("Training Accuracy: {:.4f}".format(training_accuracy))
+
+     validation_accuracy = model.test_validate('valid')
+     print("Validation Accuracy: {:.4f}".format(validation_accuracy))
+
+     if epoch > 0:
+         if validation_accuracy < 0.0001:
+             print("Validation Accuracy is too low. Stopping training.")
+             break
+
+         if validation_accuracy < 0.95 * old_validation_accuracy:
+             print("Validation Accuracy is decreasing. Stopping training.")
+             break
+
+     old_validation_accuracy = validation_accuracy
+ print("Training Complete")
+
+ print("Testing Model")
+ test_accuracy = model.test_validate('test')
+ print("Test Accuracy: {:.4f}".format(test_accuracy))
+ print("Testing Complete")
train_attention.py ADDED
@@ -0,0 +1,406 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import random
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import time
+ import argparse
+
+ random.seed()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Language Model
+ SOS_token = 0
+ EOS_token = 1
+
+ class Language:
+     def __init__(self, name):
+         self.name = name
+         self.word2index = {}
+         self.word2count = {}
+         self.index2word = {SOS_token: "<", EOS_token: ">"}
+         self.n_chars = 2  # Count SOS and EOS
+
+     def addWord(self, word):
+         for char in word:
+             self.addChar(char)
+
+     def addChar(self, char):
+         if char not in self.word2index:
+             self.word2index[char] = self.n_chars
+             self.word2count[char] = 1
+             self.index2word[self.n_chars] = char
+             self.n_chars += 1
+         else:
+             self.word2count[char] += 1
+
+ def get_data(lang: str, type: str) -> list[list[str]]:
+     """
+     Returns: 'pairs': list of [input_word, target_word] pairs
+     """
+     path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
+     df = pd.read_csv(path, header=None)
+     pairs = df.values.tolist()
+     return pairs
+
+ def get_languages(lang: str):
+     """
+     Returns
+     1. input_lang: input language - English
+     2. output_lang: output language - Given language
+     3. pairs: list of [input_word, target_word] pairs
+     """
+     input_lang = Language('eng')
+     output_lang = Language(lang)
+     pairs = get_data(lang, "train")
+     for pair in pairs:
+         input_lang.addWord(pair[0])
+         output_lang.addWord(pair[1])
+     return input_lang, output_lang, pairs
+
+ def get_cell(cell_type: str):
+     if cell_type == "LSTM":
+         return nn.LSTM
+     elif cell_type == "GRU":
+         return nn.GRU
+     elif cell_type == "RNN":
+         return nn.RNN
+     else:
+         raise Exception("Invalid cell type")
+
+ def get_optimizer(optimizer: str):
+     if optimizer == "SGD":
+         return optim.SGD
+     elif optimizer == "ADAM":
+         return optim.Adam
+     else:
+         raise Exception("Invalid optimizer")
+
+ class Encoder(nn.Module):
+     def __init__(self,
+                  in_sz: int,
+                  embed_sz: int,
+                  hidden_sz: int,
+                  cell_type: str,
+                  n_layers: int,
+                  dropout: float):
+
+         super(Encoder, self).__init__()
+         self.hidden_sz = hidden_sz
+         self.n_layers = n_layers
+         self.dropout = dropout
+         self.cell_type = cell_type
+         self.embedding = nn.Embedding(in_sz, embed_sz)
+
+         self.rnn = get_cell(cell_type)(input_size = embed_sz,
+                                        hidden_size = hidden_sz,
+                                        num_layers = n_layers,
+                                        dropout = dropout)
+
+     def forward(self, input, hidden, cell):
+         embedded = self.embedding(input).view(1, 1, -1)
+
+         if self.cell_type == "LSTM":
+             output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
+         else:
+             output, hidden = self.rnn(embedded, hidden)
+
+         return output, hidden, cell
+
+     def initHidden(self):
+         return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
+
+ class AttentionDecoder(nn.Module):
+     def __init__(self,
+                  out_sz: int,
+                  embed_sz: int,
+                  hidden_sz: int,
+                  cell_type: str,
+                  n_layers: int,
+                  dropout: float):
+
+         super(AttentionDecoder, self).__init__()
+         self.hidden_sz = hidden_sz
+         self.n_layers = n_layers
+         self.dropout = dropout
+         self.cell_type = cell_type
+         self.embedding = nn.Embedding(out_sz, embed_sz)
+
+         self.attn = nn.Linear(hidden_sz + embed_sz, 50)  # 50 = max_length: one attention score per encoder position
+         self.attn_combine = nn.Linear(hidden_sz + embed_sz, hidden_sz)
+
+         self.rnn = get_cell(cell_type)(input_size = hidden_sz,
+                                        hidden_size = hidden_sz,
+                                        num_layers = n_layers,
+                                        dropout = dropout)
+
+         self.out = nn.Linear(hidden_sz, out_sz)
+         self.softmax = nn.LogSoftmax(dim=1)
+
+     def forward(self, input, hidden, cell, encoder_outputs):
+         embedding = self.embedding(input).view(1, 1, -1)
+
+         attn_weights = F.softmax(self.attn(torch.cat((embedding[0], hidden[0]), 1)), dim=1)
+         attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
+
+         output = torch.cat((embedding[0], attn_applied[0]), 1)
+         output = self.attn_combine(output).unsqueeze(0)
+
+         if self.cell_type == "LSTM":
+             output, (hidden, cell) = self.rnn(output, (hidden, cell))
+         else:
+             output, hidden = self.rnn(output, hidden)
+
+         output = self.softmax(self.out(output[0]))
+         return output, hidden, cell, attn_weights
+
+     def initHidden(self):
+         return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
+
+ def indexesFromWord(lang: Language, word: str):
+     return [lang.word2index[char] for char in word]
+
+ def tensorFromWord(lang: Language, word: str):
+     indexes = indexesFromWord(lang, word)
+     indexes.append(EOS_token)
+     return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
+
+ def tensorsFromPair(input_lang: Language, output_lang: Language, pair: list[str]):
+     input_tensor = tensorFromWord(input_lang, pair[0])
+     target_tensor = tensorFromWord(output_lang, pair[1])
+     return (input_tensor, target_tensor)
+
+ def params_definition():
+     """
+     params:
+
+     embed_size : size of embedding (input and output) (8, 16, 32, 64)
+     hidden_size : size of hidden layer (64, 128, 256, 512)
+     cell_type : type of cell (LSTM, GRU, RNN)
+     num_layers : number of layers in encoder (1, 2, 3)
+     dropout : dropout probability
+     learning_rate : learning rate
+     teacher_forcing_ratio : teacher forcing ratio (0.5 fixed for now)
+     optimizer : optimizer (SGD, Adam)
+     max_length : maximum length of input word (50 fixed for now)
+
+     """
+     pass
+
+ PRINT_EVERY = 5000
+ PLOT_EVERY = 100
+
+ class Translator:
+     def __init__(self, lang: str, params: dict):
+         self.lang = lang
+         self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
+         self.input_size = self.input_lang.n_chars
+         self.output_size = self.output_lang.n_chars
+
+         self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair) for pair in self.pairs]
+
+         self.encoder = Encoder(in_sz = self.input_size,
+                                embed_sz = params["embed_size"],
+                                hidden_sz = params["hidden_size"],
+                                cell_type = params["cell_type"],
+                                n_layers = params["num_layers"],
+                                dropout = params["dropout"]).to(device)
+
+         self.decoder = AttentionDecoder(out_sz = self.output_size,
+                                         embed_sz = params["embed_size"],
+                                         hidden_sz = params["hidden_size"],
+                                         cell_type = params["cell_type"],
+                                         n_layers = params["num_layers"],
+                                         dropout = params["dropout"]).to(device)
+
+         self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
+         self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
+
+         self.criterion = nn.NLLLoss()
+
+         self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
+         self.max_length = params["max_length"]
+
+     def train_single(self, input_tensor, target_tensor):
+         encoder_hidden = self.encoder.initHidden()
+         encoder_cell = self.encoder.initHidden()
+
+         self.encoder_optimizer.zero_grad()
+         self.decoder_optimizer.zero_grad()
+
+         input_length = input_tensor.size(0)
+         target_length = target_tensor.size(0)
+
+         encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
+
+         loss = 0
+
+         for ei in range(input_length):
+             encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+             encoder_outputs[ei] = encoder_output[0, 0]
+
+         decoder_input = torch.tensor([[SOS_token]], device=device)
+         decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+         use_teacher_forcing = random.random() < self.teacher_forcing_ratio
+
+         if use_teacher_forcing:
+             for di in range(target_length):
+                 decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                 loss += self.criterion(decoder_output, target_tensor[di])
+
+                 decoder_input = target_tensor[di]
+         else:
+             for di in range(target_length):
+                 decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                 loss += self.criterion(decoder_output, target_tensor[di])
+
+                 topv, topi = decoder_output.topk(1)
+                 decoder_input = topi.squeeze().detach()
+                 if decoder_input.item() == EOS_token:
+                     break
+
+         loss.backward()
+         self.encoder_optimizer.step()
+         self.decoder_optimizer.step()
+
+         return loss.item() / target_length
+
+     def train(self, iters=-1):
+         start_time = time.time()
+         plot_losses = []
+         print_loss_total = 0
+         plot_loss_total = 0
+
+         random.shuffle(self.training_pairs)
+         iters = len(self.training_pairs) if iters == -1 else iters
+
+         for iter in range(1, iters + 1):
+             training_pair = self.training_pairs[iter - 1]
+             input_tensor = training_pair[0]
+             target_tensor = training_pair[1]
+
+             loss = self.train_single(input_tensor, target_tensor)
+             print_loss_total += loss
+             plot_loss_total += loss
+
+             if iter % PRINT_EVERY == 0:
+                 print_loss_avg = print_loss_total / PRINT_EVERY
+                 print_loss_total = 0
+                 current_time = time.time()
+                 print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))
+
+             if iter % PLOT_EVERY == 0:
+                 plot_loss_avg = plot_loss_total / PLOT_EVERY
+                 plot_losses.append(plot_loss_avg)
+                 plot_loss_total = 0
+
+         return plot_losses
+
+     def evaluate(self, word):
+         with torch.no_grad():
+             input_tensor = tensorFromWord(self.input_lang, word)
+             input_length = input_tensor.size()[0]
+             encoder_hidden = self.encoder.initHidden()
+             encoder_cell = self.encoder.initHidden()
+
+             encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
+
+             for ei in range(input_length):
+                 encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+                 encoder_outputs[ei] += encoder_output[0, 0]
+
+             decoder_input = torch.tensor([[SOS_token]], device=device)
+             decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+             decoded_chars = ""
+             decoder_attentions = torch.zeros(self.max_length, self.max_length)
+
+             for di in range(self.max_length):
+                 decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                 decoder_attentions[di] = decoder_attention.data
+                 topv, topi = decoder_output.topk(1)
+
+                 if topi.item() == EOS_token:
+                     break
+                 else:
+                     decoded_chars += self.output_lang.index2word[topi.item()]
+
+                 decoder_input = topi.squeeze().detach()
+
+             return decoded_chars, decoder_attentions[:di + 1]
+
+     def test_validate(self, type: str):
+         pairs = get_data(self.lang, type)
+         accuracy = 0
+         for pair in pairs:
+             output, _ = self.evaluate(pair[0])
+             if output == pair[1]:
+                 accuracy += 1
+         return accuracy / len(pairs)
+
+ params = {
+     "embed_size": 32,
+     "hidden_size": 256,
+     "cell_type": "RNN",
+     "num_layers": 2,
+     "dropout": 0,
+     "learning_rate": 0.001,
+     "optimizer": "SGD",
+     "teacher_forcing_ratio": 0.5,
+     "max_length": 50,
+     "weight_decay": 0.001
+ }
+
+ language = "tam"
+
+ parser = argparse.ArgumentParser(description="Transliteration Model with Attention")
+ parser.add_argument('-es', '--embed_size', type=int, default=32, help='Embedding size')
+ parser.add_argument('-hs', '--hidden_size', type=int, default=256, help='Hidden size')
+ parser.add_argument('-ct', '--cell_type', type=str, default='RNN', help='Cell type')
+ parser.add_argument('-nl', '--num_layers', type=int, default=2, help='Number of layers')
+ parser.add_argument('-dr', '--dropout', type=float, default=0, help='Dropout')
+ parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, help='Learning rate')
+ parser.add_argument('-op', '--optimizer', type=str, default='SGD', help='Optimizer')
+ parser.add_argument('-wd', '--weight_decay', type=float, default=0.001, help='Weight decay')
+ parser.add_argument('-l', '--lang', type=str, default='tam', help='Language')
+
+ args = parser.parse_args()
+
+ for arg in vars(args):
+     params[arg] = getattr(args, arg)
+
+ language = args.lang
+
+ print("Language: {}".format(language))
+ print("Embedding size: {}".format(params['embed_size']))
+ print("Hidden size: {}".format(params['hidden_size']))
+ print("Cell type: {}".format(params['cell_type']))
+ print("Number of layers: {}".format(params['num_layers']))
+ print("Dropout: {}".format(params['dropout']))
+ print("Learning rate: {}".format(params['learning_rate']))
+ print("Optimizer: {}".format(params['optimizer']))
+ print("Weight decay: {}".format(params['weight_decay']))
+ print("Teacher forcing ratio: {}".format(params['teacher_forcing_ratio']))
+ print("Max length: {}".format(params['max_length']))
+
+ model = Translator(language, params)
+
+ epochs = 10
+
+ for epoch in range(epochs):
+     print("Epoch: {}".format(epoch + 1))
+     model.train()
+
+     train_accuracy = model.test_validate('train')
+     print("Training Accuracy: {:.4f}".format(train_accuracy))
+
+     validation_accuracy = model.test_validate('valid')
+     print("Validation Accuracy: {:.4f}".format(validation_accuracy))
+
+     test_accuracy = model.test_validate('test')
+     print("Test Accuracy: {:.4f}".format(test_accuracy))