Commit 9e582c5 · Pankaj Singh Rawat committed · 0 parent(s)

Initial commit
Browse files
- .gitattributes +35 -0
- .gitignore +3 -0
- README copy.md +130 -0
- README.md +12 -0
- app.py +23 -0
- fast_api.py +38 -0
- inference/__init__.py +0 -0
- inference/decoder (1).pth +3 -0
- inference/demo.ipynb +185 -0
- inference/encoder (1).pth +3 -0
- inference/input_lang.pkl +3 -0
- inference/language.py +25 -0
- inference/output_lang.pkl +3 -0
- inference/transformer.py +54 -0
- inference/utility.py +172 -0
- main.py +26 -0
- notebooks/encoder_decoder_RNNs.ipynb +1924 -0
- notebooks/transformers.ipynb +1929 -0
- predictions_attention/predictions.csv +0 -0
- predictions_transformer/predictions.csv +0 -0
- predictions_vanilla/predictions _vanilla.csv +0 -0
- requirements.txt +93 -0
- src/decoder.py +46 -0
- src/encoder.py +39 -0
- src/helper.py +60 -0
- src/language.py +24 -0
- src/translator.py +159 -0
- test_best_attention.py +380 -0
- test_best_vanilla.py +35 -0
- train.py +91 -0
- train_attention.py +406 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
+wandb
+__pycache__
+transliteration
README copy.md
ADDED
@@ -0,0 +1,130 @@
+# Sequence to Sequence Language Transliteration using RNNs and Transformers
+
+This repository contains the files for the third assignment of the course CS6910 - Deep Learning at IIT Madras.
+
+The transformers part was added later and was not part of the assignment.
+
+Implemented an Encoder-Decoder architecture with/without an attention mechanism, and later with Transformers, and used them to perform transliteration on the provided Aksharantar dataset (English-Hindi transliteration pairs). These models were built using the RNN, LSTM and GRU cells provided by PyTorch.
+
+The Transformer architecture is built from scratch following the "Attention Is All You Need" paper, using only the basic feed-forward and embedding layers from PyTorch.
+
+Jump to Section: [Usage](#usage)
+
+Report: [Report](https://wandb.ai/iitmadras/CS6910_Assignment_3/reports/CS6910-Assignment-3-Report--Vmlldzo0MzQyNDk5)
+
+## Encoder
+
+The encoder is a simple cell of either LSTM, RNN or GRU. The input to the encoder is a sequence of characters and the output is a sequence of hidden states. The hidden state of the last time step is used as the context vector for the decoder.
+
+The encoder can also be a transformer encoder with multiple layers containing a self-attention mechanism. The output generated by the encoder is fed to the transformer decoder.
+
+## Decoder
+
+The decoder is again a simple cell of either LSTM, RNN or GRU. The inputs to the decoder are the hidden state of the encoder and the output of the previous time step, and the output of the decoder is a sequence of characters. The decoder has an additional fully connected layer and a log softmax which is used to predict the next character.
+
+The decoder can also be a transformer decoder with multiple layers containing masked self-attention and masked cross-attention mechanisms. The output generated by the encoder is fed as input to the transformer decoder. Next-character prediction is used to generate the complete target sequence in Hindi.
+
+## Attention Mechanism
+
+The attention mechanism is implemented as dot-product attention. The attention values are calculated as a weighted sum of the encoder hidden states, with weights given by the softmax of the dot products between the hidden states of the decoder and the hidden states of the encoder. The attention values are then concatenated with the hidden states of the decoder and passed through a fully connected layer to get the output of the decoder.
+
+## Dataset
+
+The dataset used is the Aksharantar dataset provided by the course. For each language in a subset of Indian languages, the dataset contains 3 files, namely `train.csv`, `valid.csv` and `test.csv`. I have used the Hindi dataset for this assignment. The dataset contains 2 columns, namely `English` and `Hindi` words, which are the input and output strings respectively.
+
+## Used Python Libraries and Version
+
+- Python 3.10.9
+- PyTorch 1.13.1
+- Pandas 1.5.3
+
+## Usage
+
+To run the training code for the standard encoder-decoder architecture using the best set of hyperparameters, run the following command:
+
+```bash
+python3 train.py
+```
+
+To run the training code for the encoder-decoder architecture with the attention mechanism using the best set of hyperparameters, run the following command:
+
+```bash
+python3 train_attention.py
+```
+
+To run the inference code for the standard encoder-decoder architecture using the best set of hyperparameters, run the following command (this uses the state dicts stored in the best_models folder and creates a file named test_gen.txt with the test predictions):
+
+```bash
+python3 test_best_vanilla.py
+```
+
+To run the inference code for the encoder-decoder architecture with the attention mechanism using the best set of hyperparameters, run the following command (this uses the state dicts stored in the best_models folder and creates a file named test_gen.txt with the test predictions):
+
+```bash
+python3 test_best_attention.py
+```
+
+To run with custom hyperparameters, run the following command:
+
+```bash
+python3 train.py -h
+```
+
+```bash
+# The output of the above command is as follows:
+usage: train.py [-h]
+                [-es EMBED_SIZE]
+                [-hs HIDDEN_SIZE]
+                [-ct CELL_TYPE]
+                [-nl NUM_LAYERS]
+                [-d DROPOUT]
+                [-lr LEARNING_RATE]
+                [-o OPTIMIZER]
+                [-l LANGUAGE]
+
+Transliteration Model
+
+options:
+  -h, --help                                        show this help message and exit
+  -es EMBED_SIZE, --embed_size EMBED_SIZE           Embedding Size, good_choices = [8, 16, 32]
+  -hs HIDDEN_SIZE, --hidden_size HIDDEN_SIZE        Hidden Size, good_choices = [128, 256, 512]
+  -ct CELL_TYPE, --cell_type CELL_TYPE              Cell Type, choices: [LSTM, GRU, RNN]
+  -nl NUM_LAYERS, --num_layers NUM_LAYERS           Number of Layers, choices: [1, 2, 3]
+  -d DROPOUT, --dropout DROPOUT                     Dropout, good_choices: [0, 0.1, 0.2]
+  -lr LEARNING_RATE, --learning_rate LEARNING_RATE  Learning Rate, good_choices: [0.0005, 0.001, 0.005]
+  -o OPTIMIZER, --optimizer OPTIMIZER               Optimizer, choices: [SGD, ADAM]
+  -l LANGUAGE, --language LANGUAGE                  Language
+```
+
+To run the training code for the attention mechanism with custom hyperparameters, run the following command:
+
+```bash
+python3 train_attention.py -h
+```
+
+```bash
+usage: train_attention.py [-h]
+                          [-es EMBED_SIZE]
+                          [-hs HIDDEN_SIZE]
+                          [-ct CELL_TYPE]
+                          [-nl NUM_LAYERS]
+                          [-dr DROPOUT]
+                          [-lr LEARNING_RATE]
+                          [-op OPTIMIZER]
+                          [-wd WEIGHT_DECAY]
+                          [-l LANG]
+
+Transliteration Model with Attention
+
+options:
+  -h, --help                                        show this help message and exit
+  -es EMBED_SIZE, --embed_size EMBED_SIZE           Embedding size
+  -hs HIDDEN_SIZE, --hidden_size HIDDEN_SIZE        Hidden size
+  -ct CELL_TYPE, --cell_type CELL_TYPE              Cell type
+  -nl NUM_LAYERS, --num_layers NUM_LAYERS           Number of layers
+  -dr DROPOUT, --dropout DROPOUT                    Dropout
+  -lr LEARNING_RATE, --learning_rate LEARNING_RATE  Learning rate
+  -op OPTIMIZER, --optimizer OPTIMIZER              Optimizer
+  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY     Weight decay
+  -l LANG, --lang LANG                              Language
+```
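For reference, here is a minimal sketch of the dot-product attention step the README above describes (the actual implementation lives in `src/translator.py` and `train_attention.py`, which this commit view lists but does not display; all tensor names and sizes below are illustrative assumptions, not code from the repository):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not taken from the repository)
batch, src_len, hidden = 4, 26, 256

encoder_states = torch.randn(src_len, batch, hidden)  # one hidden state per input character
decoder_state = torch.randn(1, batch, hidden)         # current decoder hidden state

# dot products between the decoder state and every encoder state, softmaxed into weights
scores = torch.einsum('sbh,tbh->bts', encoder_states, decoder_state)  # (batch, 1, src_len)
weights = F.softmax(scores, dim=-1)

# weighted sum of encoder hidden states = attention (context) vector
context = torch.bmm(weights, encoder_states.transpose(0, 1))  # (batch, 1, hidden)

# concatenate with the decoder state and pass through a fully connected layer
fc = torch.nn.Linear(2 * hidden, hidden)
output = torch.relu(fc(torch.cat((context.transpose(0, 1), decoder_state), dim=2)))
```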
README.md
ADDED
@@ -0,0 +1,12 @@
+---
+title: Transliteration
+emoji: 📈
+colorFrom: green
+colorTo: gray
+sdk: gradio
+sdk_version: 4.44.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,23 @@
+import gradio as gr
+from inference.language import Language
+from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
+from inference.transformer import generate
+
+# Run the transliteration model on the user's input
+def predict(user_input):
+    # Split the input sentence into words for the model
+    input = user_input.split(" ")
+
+    result = generate(input)
+
+    # Join the transliterated words back into a sentence
+    return " ".join(result)
+
+
+# Launch the Gradio interface
+if __name__ == "__main__":
+    gr.Interface(predict,
+                 inputs=gr.Textbox(placeholder="Your Hinglish text"),
+                 outputs=gr.Textbox(placeholder="Output Hindi text"),
+                 description="An English to Hindi Transliteration app",
+                 examples=["namaste aapko", "kese ho aap, sab badiya"]).launch(share=False)
fast_api.py
ADDED
@@ -0,0 +1,38 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from inference.language import Language
+from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
+from inference.transformer import generate
+from typing import List
+import uvicorn
+
+
+# Initialize FastAPI app
+app = FastAPI()
+
+# Create a request model to define the input for the transliteration pipeline
+class TransRequest(BaseModel):
+    query: str
+
+# Create a response model to define the output of the transliteration pipeline
+class TransResponse(BaseModel):
+    response: List[str]
+
+# Define a FastAPI endpoint for the transliteration pipeline
+@app.post("/trans", response_model=TransResponse)
+async def get_transliteration(request: TransRequest):
+    try:
+        # Run the transliteration pipeline on the words of the query
+        input = request.query.split(" ")
+        result = generate(input)
+
+        return TransResponse(
+            response=result
+        )
+    except Exception as e:
+        # In case of an error, return an HTTPException with a 500 status code
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Run the FastAPI application (for local testing)
+if __name__ == "__main__":
+    uvicorn.run(app, host="127.0.0.1", port=8000)
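Once `fast_api.py` is running locally, the `/trans` endpoint above can be exercised like this (a usage sketch; the payload shape follows the `TransRequest` model, and the example sentence is arbitrary):

```python
import requests

# POST a Hinglish sentence to the transliteration endpoint started by fast_api.py
resp = requests.post("http://127.0.0.1:8000/trans", json={"query": "namaste aapko"})
resp.raise_for_status()
print(resp.json()["response"])  # a list of transliterated (Hindi) words
```

`main.py` below wraps this same request in a Gradio interface.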
inference/__init__.py
ADDED
File without changes
inference/decoder (1).pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66c2e34a0036672d568c9b5faa09bcd80f75f0edeea6053aeacb20b8094aace6
+size 14137430
inference/demo.ipynb
ADDED
@@ -0,0 +1,185 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import pickle\n",
+    "from language import Language\n",
+    "from utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward\n",
+    "import warnings\n",
+    "from typing import List\n",
+    "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['लॉक्सलाक्राक्यालालासी']"
+      ]
+     },
+     "execution_count": 44,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "s = 'a' * 1\n",
+    "generate([s])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'^थ्रालाष्राप्टोार्फ्रास्रफ्फ्फ्'"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "output_lang.decode(o.tolist()[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([20, 4, 5, 12, 4, 3])"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "s = \"pankaj\"\n",
+    "torch.tensor(input_lang.encode(s), device=device, dtype=torch.long)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running on local URL: http://127.0.0.1:7864\n",
+      "\n",
+      "Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import requests\n",
+    "import gradio as gr\n",
+    "\n",
+    "# Define the API endpoint\n",
+    "API_URL = \"http://127.0.0.1:8000/trans\"\n",
+    "\n",
+    "# Function to call the FastAPI backend\n",
+    "def predict(user_input):\n",
+    "    # Prepare the data to send to the FastAPI API\n",
+    "    payload = {\"query\": user_input}\n",
+    "\n",
+    "    # Make a request to the FastAPI backend\n",
+    "    response = requests.post(API_URL, json=payload)\n",
+    "\n",
+    "    # Get the response JSON\n",
+    "    result = response.json()\n",
+    "\n",
+    "    # Extract the answer\n",
+    "    return \" \".join(result[\"response\"])\n",
+    "\n",
+    "\n",
+    "# Launch the Gradio interface\n",
+    "if __name__ == \"__main__\":\n",
+    "    gr.Interface(predict,\n",
+    "                 inputs=['textbox'],\n",
+    "                 outputs=['text']).launch(share=True)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gradio as gr\n",
+    "\n",
+    "def greet(name, intensity):\n",
+    "    return \"Hello, \" + name + \"!\" * int(intensity)\n",
+    "\n",
+    "demo = gr.Interface(\n",
+    "    fn=greet,\n",
+    "    inputs=[\"text\", \"slider\"],\n",
+    "    outputs=[\"text\"],\n",
+    ")\n",
+    "\n",
+    "demo.launch()\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "transliteration",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
inference/encoder (1).pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b433380e98834313e590be330b6c2543eb8131ce043922fd95d37def151d1e7c
+size 9799034
inference/input_lang.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9671af04c3790c0ae6183278806e948f2674db3506d3b1a1c65b924a4adbec78
+size 396
inference/language.py
ADDED
@@ -0,0 +1,25 @@
+class Language:
+    def __init__(self, name):
+        self.name = name
+        self.char2index = {'#': 0, '$': 1, '^': 2}  # '^': start of sequence, '$': unknown char, '#': padding
+        self.index2char = {0: '#', 1: '$', 2: '^'}
+        self.vocab_size = 3  # count of characters seen so far
+
+    def addWord(self, word):
+        for char in word:
+            self.addChar(char)
+
+    def addChar(self, char):
+        if char not in self.char2index:
+            self.char2index[char] = self.vocab_size
+            self.index2char[self.vocab_size] = char
+            self.vocab_size += 1
+
+    def encode(self, s):
+        return [self.char2index[ch] for ch in s]
+
+    def decode(self, l):
+        return ''.join([self.index2char[i] for i in l])
+
+    def vocab(self):
+        return self.char2index.keys()
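A quick round-trip through the `Language` class above (a usage sketch; the example word is arbitrary):

```python
lang = Language('eng')
lang.addWord("pankaj")

ids = lang.encode("pankaj")  # [3, 4, 5, 6, 4, 7]: indices 0-2 are the '#', '$', '^' special tokens
assert lang.decode(ids) == "pankaj"
print(lang.vocab_size)       # 8 = 3 special tokens + 5 distinct characters
```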
inference/output_lang.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:412857775d61629580d62df66b7e663bab247467f9a975920df3b1265acd6326
+size 964
inference/transformer.py
ADDED
@@ -0,0 +1,54 @@
+import torch
+import pickle
+import sys
+import os
+from inference.language import Language
+from inference.utility import Encoder, Decoder, encoderBlock, decoderBlock, MultiHeadAttention, Head, FeedForward
+import warnings
+from typing import List
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+with open(os.path.join(os.path.dirname(__file__), 'input_lang.pkl'), "rb") as file:
+    input_lang = pickle.load(file)
+
+with open(os.path.join(os.path.dirname(__file__), 'output_lang.pkl'), "rb") as file:
+    output_lang = pickle.load(file)
+
+encoder = torch.load(os.path.join(os.path.dirname(__file__), 'encoder (1).pth'), map_location=device)
+decoder = torch.load(os.path.join(os.path.dirname(__file__), 'decoder (1).pth'), map_location=device)
+
+input_vocab_size = input_lang.vocab_size
+output_vocab_size = output_lang.vocab_size
+
+def encode(s):
+    # map characters unseen during training to the '$' (unknown) token
+    return [input_lang.char2index.get(ch, input_lang.char2index['$']) for ch in s]
+
+def generate(input: List[str]) -> List[str]:
+    # pre-process the input: pad/trim every word to the same length (max_length = 33)
+    for i, inp in enumerate(input):
+        input[i] = input[i][:33] if len(input[i]) > 33 else input[i].ljust(33, '#')
+
+    input = torch.tensor([encode(i) for i in input], device=device, dtype=torch.long)
+    B, T = input.shape
+
+    encoder_output = encoder(input)
+    idx = torch.full((B, 1), 2, dtype=torch.long, device=device)  # (B,1), start-of-sequence token '^'
+
+    # idx is a (B, T) array of indices in the current context
+    for _ in range(30):
+        # get the predictions
+        logits, loss = decoder(idx, encoder_output)  # logits (B, T, vocab_size)
+        # focus only on the last time step
+        logits = logits[:, -1, :]  # becomes (B, C)
+        # greedy decoding: pick the most likely next character
+        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
+        # append the predicted index to the running sequence
+        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
+
+    ans = []
+    for id in idx:
+        # drop the start token and cut the output at the first padding character
+        ans.append(output_lang.decode(id.tolist()[1:]).split('#', 1)[0])
+    return ans
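From the repository root, the greedy decoding loop above can be driven directly (a usage sketch; it assumes the LFS-tracked `.pth` and `.pkl` files have been fetched, and the printed output is only illustrative):

```python
from inference.transformer import generate

# each word is padded/trimmed to 33 characters, then decoded greedily for 30 steps
print(generate(["namaste", "aapko"]))  # e.g. ['नमस्ते', 'आपको'] (illustrative, not a recorded result)
```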
inference/utility.py
ADDED
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+encoder_block_size = 33
+decoder_block_size = 30
+
+class Head(nn.Module):
+    """ one self-attention head """
+
+    def __init__(self, n_embd, d_k, dropout, mask=0):  # d_k is the dimension of the key; normally d_k = n_embd / 4
+        super().__init__()
+        self.mask = mask
+        self.key = nn.Linear(n_embd, d_k, bias=False, device=device)
+        self.query = nn.Linear(n_embd, d_k, bias=False, device=device)
+        self.value = nn.Linear(n_embd, d_k, bias=False, device=device)
+        if mask:
+            self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, encoder_output=None):
+        B, T, C = x.shape
+
+        # for cross-attention, keys (and values) come from the encoder output
+        if encoder_output is not None:
+            k = self.key(encoder_output)
+            Be, Te, Ce = encoder_output.shape
+        else:
+            k = self.key(x)  # (B,T,d_k)
+
+        q = self.query(x)  # (B,T,d_k)
+        # compute scaled attention scores
+        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B,T,T)
+
+        if self.mask:
+            if encoder_output is not None:
+                wei = wei.masked_fill(self.tril[:T, :Te] == 0, float('-inf'))  # (B,T,Te)
+            else:
+                wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T)
+
+        wei = F.softmax(wei, dim=-1)
+        wei = self.dropout(wei)
+        # perform weighted aggregation of values
+        if encoder_output is not None:
+            v = self.value(encoder_output)
+        else:
+            v = self.value(x)
+        out = wei @ v  # (B,T,d_k)
+        return out
+
+class MultiHeadAttention(nn.Module):
+    """ multiple self-attention heads in parallel """
+
+    def __init__(self, n_embd, num_head, d_k, dropout, mask=0):
+        super().__init__()
+        self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])
+        self.proj = nn.Linear(n_embd, n_embd)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, encoder_output=None):
+        out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)
+        out = self.dropout(self.proj(out))
+        return out
+
+class FeedForward(nn.Module):
+    """ position-wise feed-forward network """
+
+    def __init__(self, n_embd, dropout):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.ReLU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class encoderBlock(nn.Module):
+    """ Transformer encoder block: communication followed by computation """
+
+    def __init__(self, n_embd, n_head, dropout):
+        super().__init__()
+        d_k = n_embd // n_head
+        self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout)
+        self.ffwd = FeedForward(n_embd, dropout)
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.ln2 = nn.LayerNorm(n_embd)
+
+    def forward(self, x, encoder_output=None):
+        x = x + self.sa(self.ln1(x), encoder_output)
+        x = x + self.ffwd(self.ln2(x))
+        return x
+
+class Encoder(nn.Module):
+
+    def __init__(self, n_embd, n_head, n_layers, dropout):
+        super().__init__()
+
+        # input_vocab_size is a module-level global expected to be set by the loading module
+        self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd)  # n_embd: input embedding dimension
+        self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)
+        self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])
+        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
+
+    def forward(self, idx):
+        B, T = idx.shape
+        tok_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
+        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,n_embd)
+        x = tok_emb + pos_emb  # (B,T,n_embd)
+        x = self.blocks(x)  # apply the stack of attention blocks (B,T,C)
+        x = self.ln_f(x)  # (B,T,C)
+        return x
+
+
+class decoderBlock(nn.Module):
+    """ Transformer decoder block: self communication, then cross communication, followed by computation """
+
+    def __init__(self, n_embd, n_head, dropout):
+        super().__init__()
+        d_k = n_embd // n_head
+        self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask=1)
+        self.ca = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask=1)
+        self.ffwd = FeedForward(n_embd, dropout)
+        self.ln1 = nn.LayerNorm(n_embd, device=device)
+        self.ln2 = nn.LayerNorm(n_embd, device=device)
+        self.ln3 = nn.LayerNorm(n_embd, device=device)
+
+    def forward(self, x_encoder_output):
+        # nn.Sequential passes a single argument, so (x, encoder_output) travel as a tuple
+        x = x_encoder_output[0]
+        encoder_output = x_encoder_output[1]
+        x = x + self.sa(self.ln1(x))
+        x = x + self.ca(self.ln2(x), encoder_output)
+        x = x + self.ffwd(self.ln3(x))
+        return (x, encoder_output)
+
+class Decoder(nn.Module):
+
+    def __init__(self, n_embd, n_head, n_layers, dropout):
+        super().__init__()
+
+        # output_vocab_size is a module-level global expected to be set by the loading module
+        self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd)  # n_embd: output embedding dimension
+        self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)
+        self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])
+        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
+        self.lm_head = nn.Linear(n_embd, output_vocab_size)
+
+    def forward(self, idx, encoder_output, targets=None):
+        B, T = idx.shape
+
+        tok_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
+        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,n_embd)
+        x = tok_emb + pos_emb  # (B,T,n_embd)
+
+        x = self.blocks((x, encoder_output))
+        x = self.ln_f(x[0])  # (B,T,C)
+        logits = self.lm_head(x)  # (B,T,output_vocab_size)
+
+        if targets is None:
+            loss = None
+        else:
+            B, T, C = logits.shape
+            temp_logits = logits.view(B*T, C)
+            targets = targets.reshape(B*T)
+
+            loss = F.cross_entropy(temp_logits, targets.long())
+
+        # print(logits)
+        # out = torch.argmax(logits)
+
+        return logits, loss
+
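A shape check for the `Head` module above (a smoke-test sketch; the sizes are arbitrary):

```python
import torch
from inference.utility import Head, device

head = Head(n_embd=64, d_k=16, dropout=0.1)  # unmasked head, as used in the encoder
x = torch.randn(2, 10, 64, device=device)    # (batch, time, n_embd); time <= encoder_block_size
out = head(x)                                # self-attention: keys and values come from x itself
print(out.shape)                             # torch.Size([2, 10, 16]): one head returns d_k features
```

`MultiHeadAttention` concatenates `num_head` such outputs back to `n_embd` features before the output projection, which is why the blocks choose `d_k = n_embd // n_head`.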
main.py
ADDED
@@ -0,0 +1,26 @@
+import requests
+import gradio as gr
+
+# Define the API endpoint
+API_URL = "http://127.0.0.1:8000/trans"
+
+# Function to call the FastAPI backend
+def predict(user_input):
+    # Prepare the data to send to the FastAPI API
+    payload = {"query": user_input}
+
+    # Make a request to the FastAPI backend
+    response = requests.post(API_URL, json=payload)
+
+    # Get the response JSON
+    result = response.json()
+
+    # Extract the answer
+    return " ".join(result["response"])
+
+
+# Launch the Gradio interface
+if __name__ == "__main__":
+    gr.Interface(predict,
+                 inputs=['textbox'],
+                 outputs=['text']).launch(share=True)
notebooks/encoder_decoder_RNNs.ipynb
ADDED
@@ -0,0 +1,1924 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/pankajrawat9075/CS6910_assignment_3/blob/main/DL_PA3_final.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "hRdpoWePeYHn"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"## Importing Libraries and models"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": null,
|
25 |
+
"metadata": {
|
26 |
+
"id": "0LBvFtYGCNgJ"
|
27 |
+
},
|
28 |
+
"outputs": [],
|
29 |
+
"source": [
|
30 |
+
"%%capture\n",
|
31 |
+
"!pip install wandb"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {
|
38 |
+
"id": "zkZTzr7OCPBM"
|
39 |
+
},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"import wandb"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"metadata": {
|
49 |
+
"id": "z4ZVrIumZcDt"
|
50 |
+
},
|
51 |
+
"outputs": [],
|
52 |
+
"source": [
|
53 |
+
"from __future__ import unicode_literals, print_function, division\n",
|
54 |
+
"from io import open\n",
|
55 |
+
"import unicodedata\n",
|
56 |
+
"import string\n",
|
57 |
+
"import re\n",
|
58 |
+
"import random\n",
|
59 |
+
"import pandas as pd\n",
|
60 |
+
"import torch\n",
|
61 |
+
"import torch.nn as nn\n",
|
62 |
+
"from torch import optim\n",
|
63 |
+
"import torch.nn.functional as F\n",
|
64 |
+
"\n",
|
65 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
66 |
+
"torch.cuda.empty_cache()"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": null,
|
72 |
+
"metadata": {
|
73 |
+
"colab": {
|
74 |
+
"base_uri": "https://localhost:8080/"
|
75 |
+
},
|
76 |
+
"id": "qwL09v65CIse",
|
77 |
+
"outputId": "f1dcbc80-5110-48f9-d0c5-836a2daa05b4"
|
78 |
+
},
|
79 |
+
"outputs": [
|
80 |
+
{
|
81 |
+
"name": "stdout",
|
82 |
+
"output_type": "stream",
|
83 |
+
"text": [
|
84 |
+
"cuda\n"
|
85 |
+
]
|
86 |
+
}
|
87 |
+
],
|
88 |
+
"source": [
|
89 |
+
"print(device)"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "markdown",
|
94 |
+
"metadata": {
|
95 |
+
"id": "44xIRolL_T_d"
|
96 |
+
},
|
97 |
+
"source": [
|
98 |
+
"## Load Dataset"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": null,
|
104 |
+
"metadata": {
|
105 |
+
"colab": {
|
106 |
+
"base_uri": "https://localhost:8080/"
|
107 |
+
},
|
108 |
+
"id": "-XRMpx9eBzRK",
|
109 |
+
"outputId": "177ee7ae-bb7d-46ea-9269-fa3aa045a89e"
|
110 |
+
},
|
111 |
+
"outputs": [
|
112 |
+
{
|
113 |
+
"name": "stdout",
|
114 |
+
"output_type": "stream",
|
115 |
+
"text": [
|
116 |
+
"Mounted at /content/drive\n"
|
117 |
+
]
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"from google.colab import drive\n",
|
122 |
+
"drive.mount('/content/drive')"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": null,
|
128 |
+
"metadata": {
|
129 |
+
"id": "Y4zemXiyE6Fi"
|
130 |
+
},
|
131 |
+
"outputs": [],
|
132 |
+
"source": [
|
133 |
+
"class Lang:\n",
|
134 |
+
" def __init__(self, name):\n",
|
135 |
+
" self.name = name\n",
|
136 |
+
" self.char2index = {'#': 0, '$': 1, '^': 2}\n",
|
137 |
+
" self.char2count = {'#': 1, '$': 1, '^': 1}\n",
|
138 |
+
" self.index2char = {0: '#', 1: '$', 2: '^'}\n",
|
139 |
+
" self.n_chars = 3 # Count\n",
|
140 |
+
" self.data = {}\n",
|
141 |
+
" \n",
|
142 |
+
"\n",
|
143 |
+
" def addWord(self, word):\n",
|
144 |
+
" for char in word:\n",
|
145 |
+
" self.addChar(char)\n",
|
146 |
+
"\n",
|
147 |
+
" def addChar(self, char):\n",
|
148 |
+
" if char not in self.char2index:\n",
|
149 |
+
" self.char2index[char] = self.n_chars\n",
|
150 |
+
" self.char2count[char] = 1\n",
|
151 |
+
" self.index2char[self.n_chars] = char\n",
|
152 |
+
" self.n_chars += 1\n",
|
153 |
+
" else:\n",
|
154 |
+
" self.char2count[char] += 1\n",
|
155 |
+
"\n",
|
156 |
+
" \n"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "code",
|
161 |
+
"execution_count": null,
|
162 |
+
"metadata": {
|
163 |
+
"id": "dCR658yRvXpy"
|
164 |
+
},
|
165 |
+
"outputs": [],
|
166 |
+
"source": [
|
167 |
+
"# return max length of input and output words\n",
|
168 |
+
"def maxLength(data):\n",
|
169 |
+
" ip_mlen, op_mlen = 0, 0\n",
|
170 |
+
"\n",
|
171 |
+
" for i in range(len(data)):\n",
|
172 |
+
" input = data[0][i]\n",
|
173 |
+
" output = data[1][i]\n",
|
174 |
+
" if(len(input)>ip_mlen):\n",
|
175 |
+
" ip_mlen=len(input)\n",
|
176 |
+
"\n",
|
177 |
+
" if(len(output)>op_mlen):\n",
|
178 |
+
" op_mlen=len(output)\n",
|
179 |
+
"\n",
|
180 |
+
" return ip_mlen, op_mlen"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": null,
|
186 |
+
"metadata": {
|
187 |
+
"id": "IDGaCO8DkYpc"
|
188 |
+
},
|
189 |
+
"outputs": [],
|
190 |
+
"source": [
|
191 |
+
"import numpy\n",
|
192 |
+
"input_shape = 0\n",
|
193 |
+
"from torch.utils.data import TensorDataset, DataLoader\n",
|
194 |
+
"def preprocess(data, input_lang, output_lang):\n",
|
195 |
+
" maxlenInput, maxlenOutput = maxLength(data)\n",
|
196 |
+
" # we use maxlenInput as 26 since it is the maximum of all input len\n",
|
197 |
+
" maxlenInput = 26\n",
|
198 |
+
" input = numpy.zeros((len(data), maxlenInput + 1))\n",
|
199 |
+
" output = numpy.zeros((len(data), maxlenOutput + 2))\n",
|
200 |
+
" maxlenInput, maxlenOutput = maxLength(data)\n",
|
201 |
+
" unknown = input_lang.char2index['$']\n",
|
202 |
+
"\n",
|
203 |
+
" for i in range(len(data)):\n",
|
204 |
+
" op = '^' + data[1][i]\n",
|
205 |
+
" ip = data[0][i].ljust(maxlenInput + 1, '#')\n",
|
206 |
+
" op = op.ljust(maxlenOutput + 2, '#')\n",
|
207 |
+
" \n",
|
208 |
+
"\n",
|
209 |
+
" for index, char in enumerate(ip):\n",
|
210 |
+
" if input_lang.char2index.get(char) is not None:\n",
|
211 |
+
" input[i][index] = input_lang.char2index[char]\n",
|
212 |
+
" else:\n",
|
213 |
+
" input[i][index] = unknown\n",
|
214 |
+
" \n",
|
215 |
+
"\n",
|
216 |
+
" \n",
|
217 |
+
" for index, char in enumerate(op):\n",
|
218 |
+
" if output_lang.char2index.get(char) is not None:\n",
|
219 |
+
" output[i][index] = output_lang.char2index[char]\n",
|
220 |
+
" else:\n",
|
221 |
+
" output[i][index] = unknown \n",
|
222 |
+
"\n",
|
223 |
+
" print(input.shape)\n",
|
224 |
+
" print(output.shape)\n",
|
225 |
+
"\n",
|
226 |
+
" return TensorDataset(torch.from_numpy(input), torch.from_numpy(output))"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": null,
|
232 |
+
"metadata": {
|
233 |
+
"colab": {
|
234 |
+
"base_uri": "https://localhost:8080/"
|
235 |
+
},
|
236 |
+
"id": "PdS5OXKxfdCX",
|
237 |
+
"outputId": "178f1d73-5b0c-431d-ca9b-d9435b924c41"
|
238 |
+
},
|
239 |
+
"outputs": [
|
240 |
+
{
|
241 |
+
"name": "stdout",
|
242 |
+
"output_type": "stream",
|
243 |
+
"text": [
|
244 |
+
"(51200, 27)\n",
|
245 |
+
"(51200, 22)\n",
|
246 |
+
"(4096, 27)\n",
|
247 |
+
"(4096, 22)\n",
|
248 |
+
"(4096, 27)\n",
|
249 |
+
"(4096, 22)\n"
|
250 |
+
]
|
251 |
+
}
|
252 |
+
],
|
253 |
+
"source": [
|
254 |
+
"def loadData(lang):\n",
|
255 |
+
" train_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_train.csv\", header = None)\n",
|
256 |
+
" val_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_valid.csv\", header = None)\n",
|
257 |
+
" test_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_test.csv\", header = None)\n",
|
258 |
+
"\n",
|
259 |
+
" input_lang = Lang('eng')\n",
|
260 |
+
" output_lang = Lang(lang)\n",
|
261 |
+
" \n",
|
262 |
+
" # add the words to the respective languages\n",
|
263 |
+
" for i in range(len(train_df)):\n",
|
264 |
+
" \n",
|
265 |
+
" input_lang.addWord(train_df[0][i])\n",
|
266 |
+
" output_lang.addWord(train_df[1][i])\n",
|
267 |
+
"\n",
|
268 |
+
" # print(input_lang.char2index)\n",
|
269 |
+
" # print(input_lang.index2char)\n",
|
270 |
+
" trainDataset = preprocess(train_df, input_lang, output_lang)\n",
|
271 |
+
" testDataset = preprocess(test_df, input_lang, output_lang)\n",
|
272 |
+
" valDataset = preprocess(val_df, input_lang, output_lang)\n",
|
273 |
+
"\n",
|
274 |
+
" return trainDataset, testDataset, valDataset, input_lang, output_lang\n",
|
275 |
+
"\n",
|
276 |
+
"\n",
|
277 |
+
"trainData, testData, valData, ipLang, opLang = loadData('hin')\n"
|
278 |
+
]
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"cell_type": "code",
|
282 |
+
"execution_count": null,
|
283 |
+
"metadata": {
|
284 |
+
"colab": {
|
285 |
+
"base_uri": "https://localhost:8080/"
|
286 |
+
},
|
287 |
+
"id": "SvmzS5Lt_Jnl",
|
288 |
+
"outputId": "33defb60-5aee-46cb-e683-ee2df9e98436"
|
289 |
+
},
|
290 |
+
"outputs": [
|
291 |
+
{
|
292 |
+
"name": "stderr",
|
293 |
+
"output_type": "stream",
|
294 |
+
"text": [
|
295 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
|
296 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
|
297 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
|
298 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"data": {
|
303 |
+
"text/plain": [
|
304 |
+
"True"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
"execution_count": 10,
|
308 |
+
"metadata": {},
|
309 |
+
"output_type": "execute_result"
|
310 |
+
}
|
311 |
+
],
|
312 |
+
"source": [
|
313 |
+
"wandb.login(key =\"\")"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "markdown",
|
318 |
+
"metadata": {
|
319 |
+
"id": "Q1TioafYgICa"
|
320 |
+
},
|
321 |
+
"source": [
|
322 |
+
"# seq2seq model"
|
323 |
+
]
|
324 |
+
},
|
325 |
+
{
|
326 |
+
"cell_type": "markdown",
|
327 |
+
"metadata": {
|
328 |
+
"id": "svxssm9Havhb"
|
329 |
+
},
|
330 |
+
"source": [
|
331 |
+
"## Encoder"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"cell_type": "code",
|
336 |
+
"execution_count": null,
|
337 |
+
"metadata": {
|
338 |
+
"id": "YTwk8nKNcbkb"
|
339 |
+
},
|
340 |
+
"outputs": [],
|
341 |
+
"source": [
|
342 |
+
"class EncoderRNN(nn.Module):\n",
|
343 |
+
" def __init__(self, input_size, hidden_size, embedding_size, # input_size is size of input language dictionary\n",
|
344 |
+
" num_layers, cell_type,\n",
|
345 |
+
" bidirectional, dropout, batch_size) :\n",
|
346 |
+
" super(EncoderRNN, self).__init__()\n",
|
347 |
+
" self.hidden_size = hidden_size # size of an hidden state representation\n",
|
348 |
+
" self.num_layers = num_layers \n",
|
349 |
+
" self.bidirectional = True if bidirectional == 'Yes' else False\n",
|
350 |
+
" self.batch_size = batch_size\n",
|
351 |
+
" self.cell_type = cell_type\n",
|
352 |
+
" self.embedding_size=embedding_size\n",
|
353 |
+
"\n",
|
354 |
+
" # this adds the embedding layer\n",
|
355 |
+
" self.embedding = nn.Embedding(num_embeddings=input_size,embedding_dim= embedding_size)\n",
|
356 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
357 |
+
"\n",
|
358 |
+
" # this adds the Neural Network layer for the encoder\n",
|
359 |
+
" if self.cell_type == \"GRU\":\n",
|
360 |
+
" self.rnn = nn.GRU(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, dropout=dropout)\n",
|
361 |
+
" elif self.cell_type == \"LSTM\":\n",
|
362 |
+
" self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, dropout=dropout)\n",
|
363 |
+
" else:\n",
|
364 |
+
" self.rnn = nn.RNN(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, dropout=dropout)\n",
|
365 |
+
"\n",
|
366 |
+
" def forward(self, input, hidden): # input shape (seq_len, batch_size) hidden shape tuple for lstm, otherwise single\n",
|
367 |
+
" embedded = self.embedding(input.long()).view(-1,self.batch_size, self.embedding_size)\n",
|
368 |
+
" output = self.dropout(embedded) # output shape (seq_len, batch_size, embedding size)\n",
|
369 |
+
"\n",
|
370 |
+
" output, hidden = self.rnn(output, hidden) # for LSTM hidden is a tuple\n",
|
371 |
+
" if self.bidirectional:\n",
|
372 |
+
" if self.cell_type == \"LSTM\":\n",
|
373 |
+
" hidden_state = hidden[0].resize(2,self.num_layers,self.batch_size,self.hidden_size)\n",
|
374 |
+
" cell_state = hidden[1].resize(2,self.num_layers,self.batch_size,self.hidden_size)\n",
|
375 |
+
" hidden = (torch.add(hidden_state[0],hidden_state[1])/2, torch.add(cell_state[0],cell_state[1])/2)\n",
|
376 |
+
" else:\n",
|
377 |
+
" hidden=hidden.resize(2,self.num_layers,self.batch_size,self.hidden_size)\n",
|
378 |
+
" hidden=torch.add(hidden[0],hidden[1])/2\n",
|
379 |
+
" \n",
|
380 |
+
" split_tensor= torch.split(output, self.hidden_size, dim=-1)\n",
|
381 |
+
" output=torch.add(split_tensor[0],split_tensor[1])/2\n",
|
382 |
+
" return output, hidden\n",
|
383 |
+
"\n",
|
384 |
+
" # initializing the initial hidden state for the encoder\n",
|
385 |
+
" def initHidden(self):\n",
|
386 |
+
" num_directions = 2 if self.bidirectional else 1\n",
|
387 |
+
" if self.cell_type == \"LSTM\":\n",
|
388 |
+
" return (torch.zeros(self.num_layers * num_directions, self.batch_size, self.hidden_size, device=device),\n",
|
389 |
+
" torch.zeros(self.num_layers * num_directions, self.batch_size, self.hidden_size, device=device))\n",
|
390 |
+
" else:\n",
|
391 |
+
" return torch.zeros(self.num_layers * num_directions, self.batch_size, self.hidden_size, device=device)\n"
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "markdown",
|
396 |
+
"metadata": {
|
397 |
+
"id": "J56aq1J6a07q"
|
398 |
+
},
|
399 |
+
"source": [
|
400 |
+
"## Decoder"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"execution_count": null,
|
406 |
+
"metadata": {
|
407 |
+
"id": "53ki6eJUH2u2"
|
408 |
+
},
|
409 |
+
"outputs": [],
|
410 |
+
"source": [
|
411 |
+
"class DecoderRNN(nn.Module):\n",
|
412 |
+
" def __init__(self, hidden_size, output_size, embedding_size, num_layers, # output size is the size of output language dictionary\n",
|
413 |
+
" cell_type, dropout, batch_size):\n",
|
414 |
+
" super(DecoderRNN, self).__init__()\n",
|
415 |
+
" self.hidden_size = hidden_size\n",
|
416 |
+
" self.num_layers = num_layers\n",
|
417 |
+
" self.cell_type = cell_type.lower()\n",
|
418 |
+
" self.batch_size = batch_size\n",
|
419 |
+
" self.embedding_size=embedding_size\n",
|
420 |
+
"\n",
|
421 |
+
" self.embedding = nn.Embedding(output_size, embedding_size)\n",
|
422 |
+
" # self.dropout = nn.Dropout(dropout)\n",
|
423 |
+
" \n",
|
424 |
+
" if self.cell_type == \"gru\":\n",
|
425 |
+
" self.rnn = nn.GRU(embedding_size, hidden_size, num_layers=num_layers)\n",
|
426 |
+
" elif self.cell_type == \"lstm\":\n",
|
427 |
+
" self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers)\n",
|
428 |
+
" else:\n",
|
429 |
+
" self.rnn = nn.RNN(embedding_size, hidden_size, num_layers=num_layers)\n",
|
430 |
+
"\n",
|
431 |
+
" self.out = nn.Linear(hidden_size, output_size)\n",
|
432 |
+
" self.softmax = nn.LogSoftmax(dim=2)\n",
|
433 |
+
"\n",
|
434 |
+
" def forward(self, input, hidden): # input shape (1, batch_size)\n",
|
435 |
+
" embedded = self.embedding(input.long()).view(-1, self.batch_size, self.embedding_size)\n",
|
436 |
+
" # # shape (1, batch_size, embedding_size)\n",
|
437 |
+
" output = F.relu(embedded)\n",
|
438 |
+
" output, hidden = self.rnn(output, hidden) # output shape (1, batch_size, hidden_size)\n",
|
439 |
+
" output = self.softmax(self.out(output)) # shape (1, batch_size, output_size)\n",
|
440 |
+
" return output, hidden\n",
|
441 |
+
"\n",
|
442 |
+
" # not needed since hidden will be provided by the encoder"
|
443 |
+
]
|
444 |
+
},
|
445 |
+
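{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical usage sketch (all sizes here are made up): one greedy decoding\n",
"# step. The index with the highest log-probability is fed back as the next\n",
"# input, exactly as the training loop below does when teacher forcing is off.\n",
"toy_dec = DecoderRNN(hidden_size=16, output_size=30, embedding_size=8,\n",
"                     num_layers=1, cell_type=\"gru\", dropout=0.0, batch_size=2)\n",
"step_in = torch.zeros(1, 2)               # start-of-sequence indices, shape (1, batch)\n",
"h0 = torch.zeros(1, 2, 16)                # (num_layers, batch, hidden)\n",
"step_out, h1 = toy_dec(step_in, h0)       # step_out: (1, batch, output_size)\n",
"next_in = step_out.topk(1)[1].squeeze()   # greedy pick, shape (batch,)\n"
]
},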
{
|
446 |
+
"cell_type": "markdown",
|
447 |
+
"metadata": {
|
448 |
+
"id": "5JcQdylzI_Fc"
|
449 |
+
},
|
450 |
+
"source": [
|
451 |
+
"## Attention Decoder"
|
452 |
+
]
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"cell_type": "code",
|
456 |
+
"execution_count": null,
|
457 |
+
"metadata": {
|
458 |
+
"id": "R1Xysuv9I-Qr"
|
459 |
+
},
|
460 |
+
"outputs": [],
|
461 |
+
"source": [
|
462 |
+
"class AttentionDecoderRNN(nn.Module):\n",
|
463 |
+
" def __init__(self, hidden_size, output_size, embedding_size, num_layers,\n",
|
464 |
+
" cell_type, dropout, batch_size, max_length):\n",
|
465 |
+
" super(AttentionDecoderRNN, self).__init__()\n",
|
466 |
+
" self.hidden_size = hidden_size\n",
|
467 |
+
" self.num_layers = num_layers\n",
|
468 |
+
" self.cell_type = cell_type\n",
|
469 |
+
" self.batch_size = batch_size\n",
|
470 |
+
" self.embedding_size = embedding_size\n",
|
471 |
+
" self.max_length = max_length\n",
|
472 |
+
" self.dropout = dropout\n",
|
473 |
+
"\n",
|
474 |
+
" self.embedding = nn.Embedding(output_size, embedding_size)\n",
|
475 |
+
" self.dropout = nn.Dropout(self.dropout)\n",
|
476 |
+
" self.attention = nn.Linear(hidden_size + embedding_size, self.max_length)\n",
|
477 |
+
" self.attention_combine = nn.Linear(hidden_size + embedding_size, hidden_size)\n",
|
478 |
+
"\n",
|
479 |
+
" if self.cell_type == \"GRU\":\n",
|
480 |
+
" self.rnn = nn.GRU(hidden_size, hidden_size, num_layers=num_layers)\n",
|
481 |
+
" elif self.cell_type == \"LSTM\":\n",
|
482 |
+
" self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)\n",
|
483 |
+
" else:\n",
|
484 |
+
" self.rnn = nn.RNN(hidden_size, hidden_size, num_layers=num_layers)\n",
|
485 |
+
"\n",
|
486 |
+
" self.out = nn.Linear(hidden_size, output_size)\n",
|
487 |
+
" self.softmax = nn.LogSoftmax(dim=2)\n",
|
488 |
+
"\n",
|
489 |
+
" def forward(self, input, hidden, encoder_outputs): #input shape (1, batch_size)\n",
|
490 |
+
" embedded = self.embedding(input.long()).view(-1, self.batch_size, self.embedding_size) \n",
|
491 |
+
" # embedded shape (1, batch_size, embedding_size)\n",
|
492 |
+
" embedded = F.relu(embedded)\n",
|
493 |
+
"\n",
|
494 |
+
" # Compute attention scores\n",
|
495 |
+
" if self.cell_type == \"LSTM\":\n",
|
496 |
+
" attn_hidden = torch.mean(hidden[0], dim=0)\n",
|
497 |
+
" else:\n",
|
498 |
+
" attn_hidden = torch.mean(hidden, dim = 0)\n",
|
499 |
+
" attn_scores = self.attention(torch.cat((embedded, attn_hidden.unsqueeze(0)), dim=2)) # attn_scores shape (1, batch_size, max_length)\n",
|
500 |
+
" \n",
|
501 |
+
" attn_weights = F.softmax(attn_scores, dim=-1) # attn_scores shape (1, 16, 25)\n",
|
502 |
+
" \n",
|
503 |
+
"\n",
|
504 |
+
" # Apply attention weights to encoder outputs\n",
|
505 |
+
" attn_applied = torch.bmm(attn_weights.transpose(0, 1), encoder_outputs.transpose(0, 1))\n",
|
506 |
+
" \n",
|
507 |
+
" # Combine attention output and embedded input\n",
|
508 |
+
" combined = torch.cat((embedded, attn_applied.transpose(0, 1)), dim=2)\n",
|
509 |
+
" combined = self.attention_combine(combined)\n",
|
510 |
+
" combined = F.relu(combined) # shape (1, batch_size, hidden_size)\n",
|
511 |
+
"\n",
|
512 |
+
" # Run through the RNN\n",
|
513 |
+
" output, hidden = self.rnn(combined, hidden)\n",
|
514 |
+
" # output shape: (1, batch_size, hidden_size)\n",
|
515 |
+
"\n",
|
516 |
+
" # Pass through linear layer and softmax activation\n",
|
517 |
+
" output = self.out(output) # shape: (1, batch_size, output_size)\n",
|
518 |
+
" output = self.softmax(output)\n",
|
519 |
+
" return output, hidden, attn_weights.transpose(0, 1)\n"
|
520 |
+
]
|
521 |
+
},
|
522 |
+
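{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shape walk-through of the attention application above (toy sizes, added\n",
"# for illustration). attn_weights (1, batch, max_len) and the encoder outputs\n",
"# (max_len, batch, hidden) are both moved to batch-first form so torch.bmm\n",
"# can weight the encoder outputs and produce one context vector per example.\n",
"batch, max_len, hid = 4, 27, 16\n",
"w = F.softmax(torch.randn(1, batch, max_len), dim=-1)\n",
"enc = torch.randn(max_len, batch, hid)\n",
"ctx = torch.bmm(w.transpose(0, 1), enc.transpose(0, 1))\n",
"print(ctx.shape)  # torch.Size([4, 1, 16]) == (batch, 1, hidden)\n"
]
},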
{
|
523 |
+
"cell_type": "code",
|
524 |
+
"execution_count": null,
|
525 |
+
"metadata": {
|
526 |
+
"id": "LJ2Papj_jTX8"
|
527 |
+
},
|
528 |
+
"outputs": [],
|
529 |
+
"source": []
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"cell_type": "markdown",
|
533 |
+
"metadata": {
|
534 |
+
"id": "658W9RARGEUf"
|
535 |
+
},
|
536 |
+
"source": [
|
537 |
+
"# Helper functions"
|
538 |
+
]
|
539 |
+
},
|
540 |
+
{
|
541 |
+
"cell_type": "markdown",
|
542 |
+
"metadata": {
|
543 |
+
"id": "q7fAgs5uQni_"
|
544 |
+
},
|
545 |
+
"source": [
|
546 |
+
"## count matches"
|
547 |
+
]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
"cell_type": "code",
|
551 |
+
"execution_count": null,
|
552 |
+
"metadata": {
|
553 |
+
"id": "8fzy8U6_lbug"
|
554 |
+
},
|
555 |
+
"outputs": [],
|
556 |
+
"source": [
|
557 |
+
"def count_exact_matches(pred, target):\n",
|
558 |
+
" \"\"\"\n",
|
559 |
+
" Counts the number of rows in preds tensor that match exactly with each row in y tensor.\n",
|
560 |
+
" pred: tensor of shape (batch_size, seq_len-1)\n",
|
561 |
+
" y: tensor of shape (batch_size, seq_len-1)\n",
|
562 |
+
" \"\"\"\n",
|
563 |
+
" \n",
|
564 |
+
" count=0;\n",
|
565 |
+
" for i in range(pred.shape[0]):\n",
|
566 |
+
" flag = True\n",
|
567 |
+
" for j in range(pred.shape[1]):\n",
|
568 |
+
" if(target[i][j]!=pred[i][j]):\n",
|
569 |
+
" flag=False\n",
|
570 |
+
" break;\n",
|
571 |
+
" \n",
|
572 |
+
" if(flag):\n",
|
573 |
+
" count+=1;\n",
|
574 |
+
" \n",
|
575 |
+
" return count"
|
576 |
+
]
|
577 |
+
},
|
578 |
+
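{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Equivalent vectorized form of the loop above (a sketch; it returns the same\n",
"# count, just computed with tensor ops instead of Python loops).\n",
"def count_exact_matches_vectorized(pred, target):\n",
"    return (pred == target).all(dim=1).sum().item()\n"
]
},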
{
|
579 |
+
"cell_type": "markdown",
|
580 |
+
"metadata": {
|
581 |
+
"id": "n4rGh7vuQqaa"
|
582 |
+
},
|
583 |
+
"source": [
|
584 |
+
"## evaluation"
|
585 |
+
]
|
586 |
+
},
|
587 |
+
{
|
588 |
+
"cell_type": "code",
|
589 |
+
"execution_count": null,
|
590 |
+
"metadata": {
|
591 |
+
"id": "zp6gvWmDlWoB"
|
592 |
+
},
|
593 |
+
"outputs": [],
|
594 |
+
"source": [
|
595 |
+
"def evaluate(data,encoder, decoder,output_size,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention):\n",
|
596 |
+
" \n",
|
597 |
+
"\n",
|
598 |
+
"\n",
|
599 |
+
" running_loss = 0\n",
|
600 |
+
" correct =0\n",
|
601 |
+
" \n",
|
602 |
+
" loader = DataLoader(data, batch_size=batch_size)\n",
|
603 |
+
" loss_fun = nn.CrossEntropyLoss(reduction=\"sum\")\n",
|
604 |
+
" seq_len = 0\n",
|
605 |
+
"\n",
|
606 |
+
" atten_weights = torch.zeros(1,21, 27).to(device) # required to return the attention weights\n",
|
607 |
+
" predictions = torch.zeros(22-1, 1).to(device)\n",
|
608 |
+
" with torch.no_grad():\n",
|
609 |
+
" for j,(x,y) in enumerate(loader):\n",
|
610 |
+
" loss=0\n",
|
611 |
+
" encoder.eval()\n",
|
612 |
+
" decoder.eval()\n",
|
613 |
+
"\n",
|
614 |
+
" x = x.to(device)\n",
|
615 |
+
" y = y.to(device)\n",
|
616 |
+
"\n",
|
617 |
+
" x = x.T\n",
|
618 |
+
" y = y.T\n",
|
619 |
+
" seq_len = len(y)\n",
|
620 |
+
" \n",
|
621 |
+
" encoder_hidden=encoder.initHidden()\n",
|
622 |
+
" encoder_output,encoder_hidden = encoder(x,encoder_hidden)\n",
|
623 |
+
" \n",
|
624 |
+
" \n",
|
625 |
+
" decoder_input =y[0]\n",
|
626 |
+
" \n",
|
627 |
+
" # Handle different numbers of layers in the encoder and decoder\n",
|
628 |
+
" if num_layers_encoder != num_layers_decoder:\n",
|
629 |
+
" if num_layers_encoder < num_layers_decoder:\n",
|
630 |
+
" remaining_layers = num_layers_decoder - num_layers_encoder\n",
|
631 |
+
"\n",
|
632 |
+
" # Copy all encoder hidden layers and then repeat the top layer\n",
|
633 |
+
" if cell_type == \"LSTM\":\n",
|
634 |
+
" top_layer_hidden = (encoder_hidden[0][-1].unsqueeze(0), encoder_hidden[1][-1].unsqueeze(0))\n",
|
635 |
+
" extra_hidden = (top_layer_hidden[0].repeat(remaining_layers, 1, 1), top_layer_hidden[1].repeat(remaining_layers, 1, 1))\n",
|
636 |
+
" decoder_hidden = (torch.cat((encoder_hidden[0], extra_hidden[0]), dim=0), torch.cat((encoder_hidden[1], extra_hidden[1]), dim=0))\n",
|
637 |
+
" else:\n",
|
638 |
+
" top_layer_hidden = encoder_hidden[-1].unsqueeze(0) #top_layer_hidden shape (1, batch_size, hidden_size)\n",
|
639 |
+
" extra_hidden = top_layer_hidden.repeat(remaining_layers, 1, 1)\n",
|
640 |
+
" decoder_hidden = torch.cat((encoder_hidden, extra_hidden), dim=0)\n",
|
641 |
+
"\n",
|
642 |
+
" else:\n",
|
643 |
+
" # Slice the hidden states of the encoder to match the decoder layers\n",
|
644 |
+
" if cell_type == \"LSTM\":\n",
|
645 |
+
" decoder_hidden = (encoder_hidden[0][-num_layers_decoder:], encoder_hidden[1][-num_layers_decoder:])\n",
|
646 |
+
" else :\n",
|
647 |
+
" decoder_hidden = encoder_hidden[-num_layers_decoder:]\n",
|
648 |
+
" else:\n",
|
649 |
+
" decoder_hidden = encoder_hidden\n",
|
650 |
+
"\n",
|
651 |
+
" pred=torch.zeros(len(y)-1, batch_size).to(device)\n",
|
652 |
+
" atten_weight_default = torch.zeros(batch_size,1, 27).to(device)\n",
|
653 |
+
" for k in range(1,len(y)):\n",
|
654 |
+
" if attention == \"Yes\":\n",
|
655 |
+
" \n",
|
656 |
+
" decoder_output, decoder_hidden, atten_weight = decoder(decoder_input, decoder_hidden, encoder_output)\n",
|
657 |
+
" atten_weight_default = torch.cat((atten_weight_default, atten_weight), dim = 1)\n",
|
658 |
+
" else:\n",
|
659 |
+
" decoder_output, decoder_hidden= decoder(decoder_input, decoder_hidden)\n",
|
660 |
+
" max_prob, index = decoder_output.topk(1) # max_prob shape (1, batch_size, 1)\n",
|
661 |
+
" decoder_output = torch.squeeze(decoder_output)\n",
|
662 |
+
" loss += loss_fun(decoder_output, y[k].long())\n",
|
663 |
+
" pred[k-1]= torch.squeeze(index)\n",
|
664 |
+
" decoder_input = index\n",
|
665 |
+
" if attention == \"Yes\":\n",
|
666 |
+
" atten_weights = torch.cat((atten_weights, atten_weight_default[:, 1:, :]), dim = 0)\n",
|
667 |
+
"\n",
|
668 |
+
" running_loss += loss.item()\n",
|
669 |
+
" correct += count_exact_matches(pred.T,y[1:,:].T)\n",
|
670 |
+
" predictions = torch.cat((predictions, pred), dim=1)\n",
|
671 |
+
"\n",
|
672 |
+
" \n",
|
673 |
+
" avg_loss = running_loss / (len(data) * seq_len)\n",
|
674 |
+
" print(\"correct =\", correct)\n",
|
675 |
+
" avg_acc = 100 * (correct / (len(data)))\n",
|
676 |
+
" if attention == \"Yes\":\n",
|
677 |
+
" return avg_loss, avg_acc, predictions, atten_weights[1:, :, :]\n",
|
678 |
+
" else:\n",
|
679 |
+
" return avg_loss, avg_acc, predictions\n",
|
680 |
+
" \n",
|
681 |
+
" \n",
|
682 |
+
" "
|
683 |
+
]
|
684 |
+
},
|
685 |
+
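{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the layer-matching trick used in evaluate (and again in train),\n",
"# with made-up sizes: when the decoder has more layers than the encoder,\n",
"# the encoder's top hidden layer is repeated to fill the missing layers.\n",
"enc_layers, dec_layers, batch, hid = 2, 3, 4, 16\n",
"enc_h = torch.randn(enc_layers, batch, hid)\n",
"extra = enc_h[-1].unsqueeze(0).repeat(dec_layers - enc_layers, 1, 1)\n",
"dec_h = torch.cat((enc_h, extra), dim=0)\n",
"print(dec_h.shape)  # torch.Size([3, 4, 16])\n"
]
},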
{
|
686 |
+
"cell_type": "markdown",
|
687 |
+
"metadata": {
|
688 |
+
"id": "0SsnRWlgQmCI"
|
689 |
+
},
|
690 |
+
"source": [
|
691 |
+
"# Training function"
|
692 |
+
]
|
693 |
+
},
|
694 |
+
{
|
695 |
+
"cell_type": "code",
|
696 |
+
"execution_count": null,
|
697 |
+
"metadata": {
|
698 |
+
"id": "PhDgsZG0QqPW"
|
699 |
+
},
|
700 |
+
"outputs": [],
|
701 |
+
"source": [
|
702 |
+
"def train(sweeps = True, test = False):\n",
|
703 |
+
"\n",
|
704 |
+
" if sweeps == False: \n",
|
705 |
+
" configs = config_defaults # use the default configuration which has the best hyperparameters\n",
|
706 |
+
" else:\n",
|
707 |
+
" wandb.init(config= config_defaults, project='DL_assign_3') # if not test then run wandb sweeps\n",
|
708 |
+
" configs=wandb.config\n",
|
709 |
+
" \n",
|
710 |
+
"\n",
|
711 |
+
" learn_rate = configs['learn_rate']\n",
|
712 |
+
" batch_size = configs['batch_size']\n",
|
713 |
+
" hidden_size = configs['hidden_size']\n",
|
714 |
+
" embedding_size = configs['embedding_size']\n",
|
715 |
+
" num_layers_encoder = configs['num_layers_encoder']\n",
|
716 |
+
" num_layers_decoder = configs['num_layers_decoder']\n",
|
717 |
+
" cell_type = configs['cell_type']\n",
|
718 |
+
" bidirectional = configs['bidirectional']\n",
|
719 |
+
" dropout = configs['dropout']\n",
|
720 |
+
" teach_ratio = configs['teach_ratio']\n",
|
721 |
+
" epochs = configs['epochs']\n",
|
722 |
+
" attention = configs['attention']\n",
|
723 |
+
"\n",
|
724 |
+
" if sweeps:\n",
|
725 |
+
" wandb.run.name='hidden_'+str(hidden_size)+'_batch_'+str(batch_size)+'_embed_size_'+str(embedding_size)+'_dropout_'+str(dropout)+'_cell_'+str(cell_type)\n",
|
726 |
+
"\n",
|
727 |
+
" input_len = ipLang.n_chars\n",
|
728 |
+
" output_len = opLang.n_chars\n",
|
729 |
+
" \n",
|
730 |
+
" encoder = EncoderRNN(input_len, hidden_size, embedding_size, \n",
|
731 |
+
" num_layers_encoder, cell_type,\n",
|
732 |
+
" bidirectional, dropout, batch_size)\n",
|
733 |
+
" \n",
|
734 |
+
" if attention ==\"Yes\":\n",
|
735 |
+
" decoder = AttentionDecoderRNN(hidden_size, output_len, embedding_size, num_layers_decoder, \n",
|
736 |
+
" cell_type, dropout, batch_size, 27)\n",
|
737 |
+
" else:\n",
|
738 |
+
" decoder = DecoderRNN(hidden_size, output_len, embedding_size, num_layers_decoder, \n",
|
739 |
+
" cell_type, dropout, batch_size)#dropout not used\n",
|
740 |
+
" \n",
|
741 |
+
" train_loader = DataLoader(trainData, batch_size=batch_size, shuffle=True)\n",
|
742 |
+
" val_loader = DataLoader(valData, batch_size=batch_size, shuffle=True)\n",
|
743 |
+
"\n",
|
744 |
+
" encoder_optimizer=optim.Adam(encoder.parameters(),learn_rate)\n",
|
745 |
+
" decoder_optimizer=optim.Adam(decoder.parameters(),learn_rate)\n",
|
746 |
+
" loss_fun=nn.CrossEntropyLoss(reduction=\"sum\")\n",
|
747 |
+
"\n",
|
748 |
+
" encoder.to(device)\n",
|
749 |
+
" decoder.to(device)\n",
|
750 |
+
" seq_len = 0\n",
|
751 |
+
"\n",
|
752 |
+
" # Initialize variables for early stopping\n",
|
753 |
+
" best_val_loss = float('inf')\n",
|
754 |
+
" patience = 5\n",
|
755 |
+
" epochs_without_improvement = 0\n",
|
756 |
+
"\n",
|
757 |
+
" for i in range(epochs):\n",
|
758 |
+
" \n",
|
759 |
+
" running_loss = 0.0\n",
|
760 |
+
" train_correct = 0\n",
|
761 |
+
"\n",
|
762 |
+
" encoder.train()\n",
|
763 |
+
" decoder.train()\n",
|
764 |
+
"\n",
|
765 |
+
" for j,(train_x,train_y) in enumerate(train_loader):\n",
|
766 |
+
" train_x = train_x.to(device)\n",
|
767 |
+
" train_y = train_y.to(device)\n",
|
768 |
+
"\n",
|
769 |
+
" encoder_optimizer.zero_grad()\n",
|
770 |
+
" decoder_optimizer.zero_grad()\n",
|
771 |
+
"\n",
|
772 |
+
" train_x=train_x.T\n",
|
773 |
+
" train_y=train_y.T\n",
|
774 |
+
" # print(\"train_x.shapetrain_x.shape)\n",
|
775 |
+
" seq_len = len(train_y)\n",
|
776 |
+
" encoder_hidden=encoder.initHidden()\n",
|
777 |
+
" # for LSTM encoder_hidden shape ((num_layers * num_directions, batch_size,hidden_size),(self.num_layers * num_directions, batch_size, hidden_size))\n",
|
778 |
+
" encoder_output,encoder_hidden = encoder(train_x,encoder_hidden)\n",
|
779 |
+
" # encoder_hidden shape (num_layers, batch_size, hidden_size)\n",
|
780 |
+
" \n",
|
781 |
+
" \n",
|
782 |
+
" # lets move to the decoder\n",
|
783 |
+
" decoder_input = train_y[0] # shape (1, batch_size)\n",
|
784 |
+
" \n",
|
785 |
+
" # Handle different numbers of layers in the encoder and decoder\n",
|
786 |
+
" if num_layers_encoder != num_layers_decoder:\n",
|
787 |
+
" if num_layers_encoder < num_layers_decoder:\n",
|
788 |
+
" remaining_layers = num_layers_decoder - num_layers_encoder\n",
|
789 |
+
" # Copy all encoder hidden layers and then repeat the top layer\n",
|
790 |
+
" if cell_type == \"LSTM\":\n",
|
791 |
+
" top_layer_hidden = (encoder_hidden[0][-1].unsqueeze(0), encoder_hidden[1][-1].unsqueeze(0))\n",
|
792 |
+
" extra_hidden = (top_layer_hidden[0].repeat(remaining_layers, 1, 1), top_layer_hidden[1].repeat(remaining_layers, 1, 1))\n",
|
793 |
+
" decoder_hidden = (torch.cat((encoder_hidden[0], extra_hidden[0]), dim=0), torch.cat((encoder_hidden[1], extra_hidden[1]), dim=0))\n",
|
794 |
+
" else:\n",
|
795 |
+
" top_layer_hidden = encoder_hidden[-1].unsqueeze(0) #top_layer_hidden shape (1, batch_size, hidden_size)\n",
|
796 |
+
" extra_hidden = top_layer_hidden.repeat(remaining_layers, 1, 1)\n",
|
797 |
+
" decoder_hidden = torch.cat((encoder_hidden, extra_hidden), dim=0)\n",
|
798 |
+
" \n",
|
799 |
+
" else:\n",
|
800 |
+
" # Slice the hidden states of the encoder to match the decoder layers\n",
|
801 |
+
" if cell_type == \"LSTM\":\n",
|
802 |
+
" decoder_hidden = (encoder_hidden[0][-num_layers_decoder:], encoder_hidden[1][-num_layers_decoder:])\n",
|
803 |
+
" else :\n",
|
804 |
+
" decoder_hidden = encoder_hidden[-num_layers_decoder:]\n",
|
805 |
+
" else:\n",
|
806 |
+
" decoder_hidden = encoder_hidden\n",
|
807 |
+
" \n",
|
808 |
+
" loss = 0\n",
|
809 |
+
" correct = 0\n",
|
810 |
+
" \n",
|
811 |
+
" for k in range(0, len(train_y)-1):\n",
|
812 |
+
" \n",
|
813 |
+
" if attention == \"Yes\":\n",
|
814 |
+
" decoder_output, decoder_hidden, atten_weights = decoder(decoder_input, decoder_hidden, encoder_output)\n",
|
815 |
+
" else:\n",
|
816 |
+
" decoder_output, decoder_hidden= decoder(decoder_input, decoder_hidden) # decoder_output shape (1, batch_size, output_size)\n",
|
817 |
+
"\n",
|
818 |
+
" max_prob, index = decoder_output.topk(1) # max_prob shape (1, batch_size, 1)\n",
|
819 |
+
" index = torch.squeeze(index) # shape (batch_size)\n",
|
820 |
+
" decoder_output = torch.squeeze(decoder_output)\n",
|
821 |
+
" loss += loss_fun(decoder_output, train_y[k+1].long())\n",
|
822 |
+
" \n",
|
823 |
+
" correct += (index == train_y[k+1]).sum().item()\n",
|
824 |
+
"\n",
|
825 |
+
" # Apply teacher forcing\n",
|
826 |
+
" use_teacher_forcing = True if random.random() < teach_ratio else False\n",
|
827 |
+
"\n",
|
828 |
+
" if use_teacher_forcing:\n",
|
829 |
+
" decoder_input = train_y[k+1]\n",
|
830 |
+
" \n",
|
831 |
+
" else:\n",
|
832 |
+
" decoder_input = index\n",
|
833 |
+
"\n",
|
834 |
+
" running_loss += loss.item()\n",
|
835 |
+
" train_correct += correct\n",
|
836 |
+
" loss.backward()\n",
|
837 |
+
" encoder_optimizer.step()\n",
|
838 |
+
" decoder_optimizer.step()\n",
|
839 |
+
" \n",
|
840 |
+
"\n",
|
841 |
+
" # find train loss and accuracy and print + log to wandb\n",
|
842 |
+
" if attention == \"Yes\":\n",
|
843 |
+
" _, train_accuracy,_, _ = evaluate(trainData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
|
844 |
+
" else:\n",
|
845 |
+
" _, train_accuracy,_= evaluate(trainData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
|
846 |
+
" \n",
|
847 |
+
" print(f\"epoch {i}, training loss {running_loss/(len(trainData)* seq_len)}, training accuracy {train_accuracy}\")\n",
|
848 |
+
" if sweeps:\n",
|
849 |
+
" wandb.log({\"epoch\": i, \"train_loss\": running_loss/(len(trainData)* seq_len), \"train_accuracy\": train_accuracy})\n",
|
850 |
+
" \n",
|
851 |
+
" # # find validation loss and accuracy and print + log to wandb\n",
|
852 |
+
" if attention == \"Yes\":\n",
|
853 |
+
" val_loss, val_accuracy,_, _ = evaluate(valData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
|
854 |
+
" else:\n",
|
855 |
+
" val_loss, val_accuracy,_ = evaluate(valData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
|
856 |
+
" \n",
|
857 |
+
" print(f\"epoch {i}, validation loss {val_loss}, validation accuracy {val_accuracy}\")\n",
|
858 |
+
" if sweeps:\n",
|
859 |
+
" wandb.log({\"val_loss\": val_loss, \"val_accuracy\": val_accuracy})\n",
|
860 |
+
"\n",
|
861 |
+
" # Check for early stopping\n",
|
862 |
+
" if val_loss < best_val_loss:\n",
|
863 |
+
" best_val_loss = val_loss\n",
|
864 |
+
" epochs_without_improvement = 0\n",
|
865 |
+
" # Save the model weights\n",
|
866 |
+
" torch.save(encoder.state_dict(), 'best_encoder.pt')\n",
|
867 |
+
" torch.save(decoder.state_dict(), 'best_decoder.pt')\n",
|
868 |
+
" else:\n",
|
869 |
+
" epochs_without_improvement += 1\n",
|
870 |
+
" if epochs_without_improvement >= patience:\n",
|
871 |
+
" print(\"Early stopping triggered. No improvement in validation loss.\")\n",
|
872 |
+
" break\n",
|
873 |
+
" \n",
|
874 |
+
" \n",
|
875 |
+
" # if testing mode is on print the test accuracy \n",
|
876 |
+
" if test:\n",
|
877 |
+
" # Load the best model weights\n",
|
878 |
+
" encoder.load_state_dict(torch.load('best_encoder.pt'))\n",
|
879 |
+
" decoder.load_state_dict(torch.load('best_decoder.pt'))\n",
|
880 |
+
" if attention == \"Yes\":\n",
|
881 |
+
" _, test_accuracy, pred, atten_weights = evaluate(testData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
|
882 |
+
" else:\n",
|
883 |
+
" _, test_accuracy, pred = evaluate(testData,encoder, decoder,output_len,batch_size,hidden_size,num_layers_encoder,num_layers_decoder, cell_type, attention)\n",
|
884 |
+
" print(f\"test accuracy {test_accuracy}\")\n",
|
885 |
+
"\n",
|
886 |
+
" if attention == \"Yes\":\n",
|
887 |
+
" return pred, atten_weights\n",
|
888 |
+
" else:\n",
|
889 |
+
" return pred\n",
|
890 |
+
" "
|
891 |
+
]
|
892 |
+
},
|
893 |
+
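{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The teacher-forcing rule from the loop above, in isolation (toy indices;\n",
"# random is already imported for the training cell): with probability\n",
"# teach_ratio the ground-truth token is fed to the next decoder step,\n",
"# otherwise the model's own greedy prediction is fed back.\n",
"teach_ratio = 0.5\n",
"gold_token, predicted_token = 7, 3   # hypothetical character indices\n",
"next_input = gold_token if random.random() < teach_ratio else predicted_token\n"
]
},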
{
|
894 |
+
"cell_type": "markdown",
|
895 |
+
"metadata": {
|
896 |
+
"id": "nvyRJWUUbR2f"
|
897 |
+
},
|
898 |
+
"source": [
|
899 |
+
"# Translating predictions to words\n"
|
900 |
+
]
|
901 |
+
},
|
902 |
+
{
|
903 |
+
"cell_type": "code",
|
904 |
+
"execution_count": null,
|
905 |
+
"metadata": {
|
906 |
+
"id": "Hd3zCTnSbSaL"
|
907 |
+
},
|
908 |
+
"outputs": [],
|
909 |
+
"source": [
|
910 |
+
"def translate_prediction(input_dict , input, output_dict, pred,target):\n",
|
911 |
+
" \n",
|
912 |
+
" '''pred in shape of seq_len-1 * dataset_size\n",
|
913 |
+
" target in shape datasize * seq_len-1\n",
|
914 |
+
" '''\n",
|
915 |
+
" pred = pred.T # shape datasize * seq len-1\n",
|
916 |
+
" pred = pred[1:, :-1] # ignore last index of each row\n",
|
917 |
+
" input = input[:, :-1] # ignore last index of each row\n",
|
918 |
+
" target = target[:, 1:-1] # ignore last index of each row\n",
|
919 |
+
" print(f\"pred shape {pred.shape}, input shape {input.shape}, target shape {target.shape}\")\n",
|
920 |
+
" predictions = [] \n",
|
921 |
+
" Input = [] \n",
|
922 |
+
" Target = []\n",
|
923 |
+
" for i in range(len(pred)):\n",
|
924 |
+
" \n",
|
925 |
+
" pred_word=\"\"\n",
|
926 |
+
" input_word=\"\"\n",
|
927 |
+
" target_word = \"\"\n",
|
928 |
+
"\n",
|
929 |
+
" for j in range(pred.shape[1]):\n",
|
930 |
+
"\n",
|
931 |
+
" # Ignore padding\n",
|
932 |
+
" if(target[i][j].item() != 0):\n",
|
933 |
+
" \n",
|
934 |
+
" pred_word += output_dict[pred[i][j].item()]\n",
|
935 |
+
" target_word += output_dict[target[i][j].item()]\n",
|
936 |
+
" \n",
|
937 |
+
" for j in range(input.shape[1]):\n",
|
938 |
+
" \n",
|
939 |
+
" if(input[i][j].item()!=0):\n",
|
940 |
+
" \n",
|
941 |
+
" input_word += input_dict[input[i][j].item()] \n",
|
942 |
+
"\n",
|
943 |
+
" # Append words in respective List\n",
|
944 |
+
" \n",
|
945 |
+
" predictions.append(pred_word)\n",
|
946 |
+
" Input.append(input_word) \n",
|
947 |
+
" Target.append(target_word) \n",
|
948 |
+
"\n",
|
949 |
+
" # Create a DataFrame\n",
|
950 |
+
" df = pd.DataFrame({\"input\": Input, \"predicted\": predictions,\"Actual\":Target})\n",
|
951 |
+
" return df\n",
|
952 |
+
"\n",
|
953 |
+
" "
|
954 |
+
]
|
955 |
+
},
|
956 |
+
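{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of the padding rule above (index 0 is assumed to be the\n",
"# pad character, as in the notebook's language dictionaries): padded\n",
"# positions are simply skipped when indices are mapped back to characters.\n",
"toy_dict = {1: 'a', 2: 'b', 3: 'c'}\n",
"row = [2, 1, 3, 0, 0]\n",
"print(''.join(toy_dict[t] for t in row if t != 0))  # 'bac'\n"
]
},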
{
|
957 |
+
"cell_type": "markdown",
|
958 |
+
"metadata": {
|
959 |
+
"id": "8ETW0BG_Pa24"
|
960 |
+
},
|
961 |
+
"source": [
|
962 |
+
"#call train"
|
963 |
+
]
|
964 |
+
},
|
965 |
+
{
|
966 |
+
"cell_type": "code",
|
967 |
+
"execution_count": null,
|
968 |
+
"metadata": {
|
969 |
+
"id": "pgGp7MoGzfPg"
|
970 |
+
},
|
971 |
+
"outputs": [],
|
972 |
+
"source": [
|
973 |
+
"# train(sweeps = False, test = True)"
|
974 |
+
]
|
975 |
+
},
|
976 |
+
{
|
977 |
+
"cell_type": "markdown",
|
978 |
+
"metadata": {
|
979 |
+
"id": "MQPGy32rnD3V"
|
980 |
+
},
|
981 |
+
"source": [
|
982 |
+
"# Runnning sweeps for models without Attention\n",
|
983 |
+
"\n"
|
984 |
+
]
|
985 |
+
},
|
986 |
+
{
|
987 |
+
"cell_type": "markdown",
|
988 |
+
"metadata": {
|
989 |
+
"id": "z_aYZvDD1OHU"
|
990 |
+
},
|
991 |
+
"source": [
|
992 |
+
"## Sweep Config"
|
993 |
+
]
|
994 |
+
},
|
995 |
+
{
|
996 |
+
"cell_type": "code",
|
997 |
+
"execution_count": null,
|
998 |
+
"metadata": {
|
999 |
+
"id": "SVv8bI-D1Q_I"
|
1000 |
+
},
|
1001 |
+
"outputs": [],
|
1002 |
+
"source": [
|
1003 |
+
"sweep_config = {\n",
|
1004 |
+
" 'name': 'sweepDL', \n",
|
1005 |
+
" 'method': 'bayes',\n",
|
1006 |
+
" 'metric': {\n",
|
1007 |
+
" 'name': 'val_accuracy',\n",
|
1008 |
+
" 'goal': 'maximize'\n",
|
1009 |
+
" },\n",
|
1010 |
+
" 'parameters': {\n",
|
1011 |
+
" \n",
|
1012 |
+
" 'learn_rate': {\n",
|
1013 |
+
" 'values': [0.01, 0.001, 0.001]\n",
|
1014 |
+
" },\n",
|
1015 |
+
" 'embedding_size': {\n",
|
1016 |
+
" 'values': [32, 64, 128, 256, 512, 1024]\n",
|
1017 |
+
" },\n",
|
1018 |
+
" 'batch_size':{\n",
|
1019 |
+
" 'values':[16, 32, 64, 128, 256]\n",
|
1020 |
+
" },\n",
|
1021 |
+
" 'hidden_size':{\n",
|
1022 |
+
" 'values':[32, 64, 128, 256, 512, 1024]\n",
|
1023 |
+
" },\n",
|
1024 |
+
" 'teach_ratio':{\n",
|
1025 |
+
" 'values':[0.4, 0.5, 0.6]\n",
|
1026 |
+
" },\n",
|
1027 |
+
" 'dropout':{\n",
|
1028 |
+
" 'values':[0, 0.2, 0.4]\n",
|
1029 |
+
" },\n",
|
1030 |
+
" 'cell_type':{\n",
|
1031 |
+
" 'values':[\"RNN\", \"LSTM\", \"GRU\"]\n",
|
1032 |
+
" },\n",
|
1033 |
+
" 'bidirectional':{\n",
|
1034 |
+
" 'values' : [\"Yes\",\"No\"]\n",
|
1035 |
+
" },\n",
|
1036 |
+
" 'num_layers_decoder':{\n",
|
1037 |
+
" 'values': [1,2, 3, 4]\n",
|
1038 |
+
" },\n",
|
1039 |
+
" 'num_layers_encoder':{\n",
|
1040 |
+
" 'values': [1,2,3,4]\n",
|
1041 |
+
" },\n",
|
1042 |
+
" 'epochs':{\n",
|
1043 |
+
" 'values': [10, 15, 20, 25, 30]\n",
|
1044 |
+
" },\n",
|
1045 |
+
" 'attention':{\n",
|
1046 |
+
" 'values': [\"Yes\"]\n",
|
1047 |
+
" }\n",
|
1048 |
+
" \n",
|
1049 |
+
" }\n",
|
1050 |
+
"}\n",
|
1051 |
+
"config_defaults={\n",
|
1052 |
+
" 'learn_rate' : 0.001,\n",
|
1053 |
+
" 'embedding_size': 32,\n",
|
1054 |
+
" 'batch_size': 256,\n",
|
1055 |
+
" 'hidden_size' : 1024,\n",
|
1056 |
+
" 'num_layers_encoder': 3,\n",
|
1057 |
+
" 'num_layers_decoder': 3,\n",
|
1058 |
+
" 'bidirectional': 'No',\n",
|
1059 |
+
" 'cell_type': \"LSTM\",\n",
|
1060 |
+
" 'teach_ratio': 0.6,\n",
|
1061 |
+
" 'dropout': 0.4,\n",
|
1062 |
+
" 'epochs': 15,\n",
|
1063 |
+
" 'attention': \"No\"\n",
|
1064 |
+
"}"
|
1065 |
+
]
|
1066 |
+
},
|
1067 |
+
{
|
1068 |
+
"cell_type": "code",
|
1069 |
+
"execution_count": null,
|
1070 |
+
"metadata": {
|
1071 |
+
"id": "4KxsOOpvr1oi"
|
1072 |
+
},
|
1073 |
+
"outputs": [],
|
1074 |
+
"source": [
|
1075 |
+
"sweep_id=wandb.sweep(sweep_config, project=\"CS6910_Assignment_3\")\n",
|
1076 |
+
"wandb.agent(sweep_id,function=train)"
|
1077 |
+
]
|
1078 |
+
},
|
1079 |
+
{
|
1080 |
+
"cell_type": "markdown",
|
1081 |
+
"metadata": {
|
1082 |
+
"id": "pKvBd5mKf0Hf"
|
1083 |
+
},
|
1084 |
+
"source": [
|
1085 |
+
"# Testing the Best Model(without Attention) on Test Data \n",
|
1086 |
+
"Set default hyperparameters to the best hyperparameters got from sweeps Hyperparamer tuning"
|
1087 |
+
]
|
1088 |
+
},
|
1089 |
+
{
|
1090 |
+
"cell_type": "code",
|
1091 |
+
"execution_count": null,
|
1092 |
+
"metadata": {
|
1093 |
+
"id": "kMQvZjZl0q4U"
|
1094 |
+
},
|
1095 |
+
"outputs": [],
|
1096 |
+
"source": [
|
1097 |
+
"config_defaults={\n",
|
1098 |
+
" 'learn_rate' : 0.001,\n",
|
1099 |
+
" 'embedding_size': 32,\n",
|
1100 |
+
" 'batch_size': 256,\n",
|
1101 |
+
" 'hidden_size' : 1024,\n",
|
1102 |
+
" 'num_layers_encoder': 3,\n",
|
1103 |
+
" 'num_layers_decoder': 3,\n",
|
1104 |
+
" 'bidirectional': 'No',\n",
|
1105 |
+
" 'cell_type': \"LSTM\",\n",
|
1106 |
+
" 'teach_ratio': 0.6,\n",
|
1107 |
+
" 'dropout': 0.4,\n",
|
1108 |
+
" 'epochs': 15,\n",
|
1109 |
+
" 'attention': \"No\"\n",
|
1110 |
+
"}"
|
1111 |
+
]
|
1112 |
+
},
|
1113 |
+
{
|
1114 |
+
"cell_type": "code",
|
1115 |
+
"execution_count": null,
|
1116 |
+
"metadata": {
|
1117 |
+
"colab": {
|
1118 |
+
"base_uri": "https://localhost:8080/"
|
1119 |
+
},
|
1120 |
+
"id": "ygtFpEvp8jFU",
|
1121 |
+
"outputId": "1a71d3be-f17f-498c-8844-3c115c411f0a"
|
1122 |
+
},
|
1123 |
+
"outputs": [
|
1124 |
+
{
|
1125 |
+
"name": "stdout",
|
1126 |
+
"output_type": "stream",
|
1127 |
+
"text": [
|
1128 |
+
"correct = 1490\n",
|
1129 |
+
"test accuracy 36.376953125\n"
|
1130 |
+
]
|
1131 |
+
}
|
1132 |
+
],
|
1133 |
+
"source": [
|
1134 |
+
"pred= train(sweeps = False, test = True)"
|
1135 |
+
]
|
1136 |
+
},
|
1137 |
+
{
|
1138 |
+
"cell_type": "markdown",
|
1139 |
+
"metadata": {
|
1140 |
+
"id": "hMf0OAuscOJx"
|
1141 |
+
},
|
1142 |
+
"source": [
|
1143 |
+
"# Saving the predictions by Vanilla model in csv file"
|
1144 |
+
]
|
1145 |
+
},
|
1146 |
+
{
|
1147 |
+
"cell_type": "code",
|
1148 |
+
"execution_count": null,
|
1149 |
+
"metadata": {
|
1150 |
+
"colab": {
|
1151 |
+
"base_uri": "https://localhost:8080/"
|
1152 |
+
},
|
1153 |
+
"id": "1cgUOUdsfzUB",
|
1154 |
+
"outputId": "8784a3aa-315e-476f-cced-c38ebb8434b3"
|
1155 |
+
},
|
1156 |
+
"outputs": [
|
1157 |
+
{
|
1158 |
+
"name": "stdout",
|
1159 |
+
"output_type": "stream",
|
1160 |
+
"text": [
|
1161 |
+
"pred shape torch.Size([4096, 20]), input shape torch.Size([4096, 26]), target shape torch.Size([4096, 20])\n"
|
1162 |
+
]
|
1163 |
+
}
|
1164 |
+
],
|
1165 |
+
"source": [
|
1166 |
+
"# save the predictions\n",
|
1167 |
+
"dataframe = translate_prediction(ipLang.index2char, testData[:][0], opLang.index2char, pred, testData[:][1])\n",
|
1168 |
+
"dataframe.to_csv(\"predictions.csv\")"
|
1169 |
+
]
|
1170 |
+
},
|
1171 |
+
{
|
1172 |
+
"cell_type": "code",
|
1173 |
+
"execution_count": null,
|
1174 |
+
"metadata": {
|
1175 |
+
"id": "ZZW-IEWZ5syU"
|
1176 |
+
},
|
1177 |
+
"outputs": [],
|
1178 |
+
"source": [
|
1179 |
+
"import pandas as pd\n",
|
1180 |
+
"data = pd.read_csv(\"predictions.csv\")"
|
1181 |
+
]
|
1182 |
+
},
|
1183 |
+
{
|
1184 |
+
"cell_type": "code",
|
1185 |
+
"execution_count": null,
|
1186 |
+
"metadata": {
|
1187 |
+
"colab": {
|
1188 |
+
"base_uri": "https://localhost:8080/",
|
1189 |
+
"height": 424
|
1190 |
+
},
|
1191 |
+
"id": "2sOkc_0vmDlB",
|
1192 |
+
"outputId": "750d06b5-fee2-4eb8-d7e6-a7043cd0c15a"
|
1193 |
+
},
|
1194 |
+
"outputs": [],
|
1195 |
+
"source": [
|
1196 |
+
"data"
|
1197 |
+
]
|
1198 |
+
},
|
1199 |
+
{
|
1200 |
+
"cell_type": "code",
|
1201 |
+
"execution_count": null,
|
1202 |
+
"metadata": {
|
1203 |
+
"colab": {
|
1204 |
+
"base_uri": "https://localhost:8080/",
|
1205 |
+
"height": 142
|
1206 |
+
},
|
1207 |
+
"id": "AkG1vCpZ_vjG",
|
1208 |
+
"outputId": "d64b794c-d173-4871-80fc-93b8211ebedc"
|
1209 |
+
},
|
1210 |
+
"outputs": [],
|
1211 |
+
"source": [
|
1212 |
+
"# We also want to plot the prdiction table to wandb\n",
|
1213 |
+
"wandb.init(project=\"CS6910_Assignment_3\")"
|
1214 |
+
]
|
1215 |
+
},
|
1216 |
+
{
|
1217 |
+
"cell_type": "code",
|
1218 |
+
"execution_count": null,
|
1219 |
+
"metadata": {
|
1220 |
+
"id": "MmKDX6V5_kGu"
|
1221 |
+
},
|
1222 |
+
"outputs": [],
|
1223 |
+
"source": [
|
1224 |
+
"table = wandb.Table(dataframe=data)\n",
|
1225 |
+
"wandb.log({\"data\": table})"
|
1226 |
+
]
|
1227 |
+
},
|
1228 |
+
{
|
1229 |
+
"cell_type": "markdown",
|
1230 |
+
"metadata": {
|
1231 |
+
"id": "FYMa5jTQRUaB"
|
1232 |
+
},
|
1233 |
+
"source": [
|
1234 |
+
"## Plotting the confusion matrix in wandB"
|
1235 |
+
]
|
1236 |
+
},
|
1237 |
+
{
|
1238 |
+
"cell_type": "code",
|
1239 |
+
"execution_count": null,
|
1240 |
+
"metadata": {
|
1241 |
+
"id": "YBaJZCIBRAGZ"
|
1242 |
+
},
|
1243 |
+
"outputs": [],
|
1244 |
+
"source": [
|
1245 |
+
"import numpy as np\n",
|
1246 |
+
"CM = np.zeros((opLang.n_chars, ipLang.n_chars))\n",
|
1247 |
+
"\n",
|
1248 |
+
"for i in range(len(testData[1])):\n",
|
1249 |
+
" for j in range(testData[1].shape[1]):\n",
|
1250 |
+
" pred = int(pred[i][j])\n",
|
1251 |
+
" targ = int(testData[1][i][j])\n",
|
1252 |
+
" CM[pred][targ] += 1\n",
|
1253 |
+
"\n",
|
1254 |
+
"classes =[]\n",
|
1255 |
+
"\n",
|
1256 |
+
"for i in range(len(CM)):\n",
|
1257 |
+
" classes.append(opLang.index2char[i])\n",
|
1258 |
+
"\n",
|
1259 |
+
"percentages = 100 * (CM / np.sum(CM))\n",
|
1260 |
+
"\n",
|
1261 |
+
"# Define the text for each cell\n",
|
1262 |
+
"cell_text = []\n",
|
1263 |
+
"for i in range(len(classes)):\n",
|
1264 |
+
" row_text = []\n",
|
1265 |
+
" for j in range(len(classes)):\n",
|
1266 |
+
"\n",
|
1267 |
+
" txt = \"Total \"+f'{CM[i, j]}Per. ({percentages[i, j]:.3f})'\n",
|
1268 |
+
" if(i==j):\n",
|
1269 |
+
" txt =\"Correcty Predicted \" +classes[i]+\"\"+txt\n",
|
1270 |
+
" if(i!=j):\n",
|
1271 |
+
" txt =\"Predicted \" +classes[j]+\" For \"+classes[i]+\"\"+txt\n",
|
1272 |
+
" row_text.append(txt)\n",
|
1273 |
+
" cell_text.append(row_text)\n",
|
1274 |
+
"\n",
|
1275 |
+
"import plotly.graph_objs as go\n",
|
1276 |
+
"\n",
|
1277 |
+
"# Define the trace\n",
|
1278 |
+
"trace = go.Heatmap(z=percentages,\n",
|
1279 |
+
" x=classes,\n",
|
1280 |
+
" y=classes,\n",
|
1281 |
+
" colorscale='Blues',\n",
|
1282 |
+
" colorbar=dict(title='Percentage'),\n",
|
1283 |
+
" hovertemplate='%{text}%',\n",
|
1284 |
+
" text=cell_text,\n",
|
1285 |
+
" )\n",
|
1286 |
+
"\n",
|
1287 |
+
"# Define the layout\n",
|
1288 |
+
"layout = go.Layout(title='Confusion Matrix',\n",
|
1289 |
+
" xaxis=dict(title='Predicted Character'),\n",
|
1290 |
+
" yaxis=dict(title='True Character'),\n",
|
1291 |
+
" )\n",
|
1292 |
+
"\n",
|
1293 |
+
"# Plot the figure\n",
|
1294 |
+
"fig = go.Figure(data=[trace], layout=layout)\n",
|
1295 |
+
"wandb.log({'confusion_matrix': (fig)})"
|
1296 |
+
]
|
1297 |
+
},
|
1298 |
+
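{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick summary of the matrix built above (illustrative): the diagonal holds\n",
"# correctly predicted characters, so per-character accuracy is trace / total.\n",
"char_acc = 100 * np.trace(CM) / np.sum(CM)\n",
"print(f\"per-character accuracy {char_acc:.2f}%\")\n"
]
},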
{
|
1299 |
+
"cell_type": "markdown",
|
1300 |
+
"metadata": {
|
1301 |
+
"id": "zfuv5FoA1wt2"
|
1302 |
+
},
|
1303 |
+
"source": [
|
1304 |
+
"# Runnning sweeps for models with Attention\n"
|
1305 |
+
]
|
1306 |
+
},
|
1307 |
+
{
|
1308 |
+
"cell_type": "markdown",
|
1309 |
+
"metadata": {
|
1310 |
+
"id": "tsHS0PkNGHdV"
|
1311 |
+
},
|
1312 |
+
"source": [
|
1313 |
+
"## Sweep Config"
|
1314 |
+
]
|
1315 |
+
},
|
1316 |
+
{
|
1317 |
+
"cell_type": "code",
|
1318 |
+
"execution_count": null,
|
1319 |
+
"metadata": {
|
1320 |
+
"id": "HwCn-Ci5xkTb"
|
1321 |
+
},
|
1322 |
+
"outputs": [],
|
1323 |
+
"source": [
|
1324 |
+
"sweep_config = {\n",
|
1325 |
+
" 'name': 'sweepDL', \n",
|
1326 |
+
" 'method': 'bayes',\n",
|
1327 |
+
" 'metric': {\n",
|
1328 |
+
" 'name': 'val_accuracy',\n",
|
1329 |
+
" 'goal': 'maximize'\n",
|
1330 |
+
" },\n",
|
1331 |
+
" 'parameters': {\n",
|
1332 |
+
" \n",
|
1333 |
+
" 'learn_rate': {\n",
|
1334 |
+
" 'values': [0.01, 0.001, 0.001]\n",
|
1335 |
+
" },\n",
|
1336 |
+
" 'embedding_size': {\n",
|
1337 |
+
" 'values': [32, 64, 128, 256, 512, 1024]\n",
|
1338 |
+
" },\n",
|
1339 |
+
" 'batch_size':{\n",
|
1340 |
+
" 'values':[16, 32, 64, 128, 256]\n",
|
1341 |
+
" },\n",
|
1342 |
+
" 'hidden_size':{\n",
|
1343 |
+
" 'values':[32, 64, 128, 256, 512, 1024]\n",
|
1344 |
+
" },\n",
|
1345 |
+
" 'teach_ratio':{\n",
|
1346 |
+
" 'values':[0.4, 0.5, 0.6]\n",
|
1347 |
+
" },\n",
|
1348 |
+
" 'dropout':{\n",
|
1349 |
+
" 'values':[0, 0.2, 0.4]\n",
|
1350 |
+
" },\n",
|
1351 |
+
" 'cell_type':{\n",
|
1352 |
+
" 'values':[\"RNN\", \"LSTM\", \"GRU\"]\n",
|
1353 |
+
" },\n",
|
1354 |
+
" 'bidirectional':{\n",
|
1355 |
+
" 'values' : [\"Yes\",\"No\"]\n",
|
1356 |
+
" },\n",
|
1357 |
+
" 'num_layers_decoder':{\n",
|
1358 |
+
" 'values': [1,2, 3, 4]\n",
|
1359 |
+
" },\n",
|
1360 |
+
" 'num_layers_encoder':{\n",
|
1361 |
+
" 'values': [1,2,3,4]\n",
|
1362 |
+
" },\n",
|
1363 |
+
" 'epochs':{\n",
|
1364 |
+
" 'values': [10, 15, 20, 25, 30]\n",
|
1365 |
+
" },\n",
|
1366 |
+
" 'attention':{\n",
|
1367 |
+
" 'values': [\"Yes\"]\n",
|
1368 |
+
" }\n",
|
1369 |
+
" \n",
|
1370 |
+
" }\n",
|
1371 |
+
"}\n",
|
1372 |
+
"config_defaults={\n",
|
1373 |
+
" 'learn_rate' : 0.001,\n",
|
1374 |
+
" 'embedding_size': 32,\n",
|
1375 |
+
" 'batch_size': 64,\n",
|
1376 |
+
" 'hidden_size' : 1024,\n",
|
1377 |
+
" 'num_layers_encoder': 1,\n",
|
1378 |
+
" 'num_layers_decoder': 1,\n",
|
1379 |
+
" 'bidirectional': 'Yes',\n",
|
1380 |
+
" 'cell_type': \"LSTM\",\n",
|
1381 |
+
" 'teach_ratio': 0.5,\n",
|
1382 |
+
" 'dropout': 0.4,\n",
|
1383 |
+
" 'epochs': 20,\n",
|
1384 |
+
" 'attention': \"Yes\"\n",
|
1385 |
+
"}"
|
1386 |
+
]
|
1387 |
+
},
|
1388 |
+
{
|
1389 |
+
"cell_type": "code",
|
1390 |
+
"execution_count": null,
|
1391 |
+
"metadata": {
|
1392 |
+
"id": "3ADMwinqaQVF"
|
1393 |
+
},
|
1394 |
+
"outputs": [],
|
1395 |
+
"source": [
|
1396 |
+
"sweep_id=wandb.sweep(sweep_config, project=\"CS6910_Assignment_3\")\n",
|
1397 |
+
"wandb.agent(sweep_id,function=train)\n",
|
1398 |
+
"# wandb.agent(sweep_id= \"xiyggu44\",function=train, project=\"CS6910_Assignment_3\")"
|
1399 |
+
]
|
1400 |
+
},
|
1401 |
+
{
|
1402 |
+
"cell_type": "markdown",
|
1403 |
+
"metadata": {
|
1404 |
+
"id": "W7CYNChRGuGK"
|
1405 |
+
},
|
1406 |
+
"source": [
|
1407 |
+
"# Testing the Best Model(with Attention) on Test Data \n",
|
1408 |
+
"Set default hyperparameters to the best hyperparameters got from sweeps Hyperparamer tuning"
|
1409 |
+
]
|
1410 |
+
},
|
1411 |
+
{
|
1412 |
+
"cell_type": "code",
|
1413 |
+
"execution_count": null,
|
1414 |
+
"metadata": {
|
1415 |
+
"id": "C9MUrsXu_Rr4"
|
1416 |
+
},
|
1417 |
+
"outputs": [],
|
1418 |
+
"source": [
|
1419 |
+
"config_defaults={\n",
|
1420 |
+
" 'learn_rate' : 0.001,\n",
|
1421 |
+
" 'embedding_size': 32,\n",
|
1422 |
+
" 'batch_size': 64,\n",
|
1423 |
+
" 'hidden_size' : 1024,\n",
|
1424 |
+
" 'num_layers_encoder': 1,\n",
|
1425 |
+
" 'num_layers_decoder': 1,\n",
|
1426 |
+
" 'bidirectional': 'Yes',\n",
|
1427 |
+
" 'cell_type': \"LSTM\",\n",
|
1428 |
+
" 'teach_ratio': 0.5,\n",
|
1429 |
+
" 'dropout': 0.4,\n",
|
1430 |
+
" 'epochs': 20,\n",
|
1431 |
+
" 'attention': \"Yes\"\n",
|
1432 |
+
"}"
|
1433 |
+
]
|
1434 |
+
},
|
1435 |
+
{
|
1436 |
+
"cell_type": "code",
|
1437 |
+
"execution_count": null,
|
1438 |
+
"metadata": {
|
1439 |
+
"id": "u7XAB4Q5Hpxj"
|
1440 |
+
},
|
1441 |
+
"outputs": [],
|
1442 |
+
"source": [
|
1443 |
+
"pred, atten_weights = train(sweeps = False, test = True)"
|
1444 |
+
]
|
1445 |
+
},
|
1446 |
+
{
|
1447 |
+
"cell_type": "markdown",
|
1448 |
+
"metadata": {
|
1449 |
+
"id": "fld21YRZdRdG"
|
1450 |
+
},
|
1451 |
+
"source": [
|
1452 |
+
"# Saving the predictions by Vanilla model in csv file"
|
1453 |
+
]
|
1454 |
+
},
|
1455 |
+
{
|
1456 |
+
"cell_type": "code",
|
1457 |
+
"execution_count": null,
|
1458 |
+
"metadata": {
|
1459 |
+
"colab": {
|
1460 |
+
"base_uri": "https://localhost:8080/"
|
1461 |
+
},
|
1462 |
+
"id": "BpDQ1mrydYWg",
|
1463 |
+
"outputId": "8784a3aa-315e-476f-cced-c38ebb8434b3"
|
1464 |
+
},
|
1465 |
+
"outputs": [
|
1466 |
+
{
|
1467 |
+
"name": "stdout",
|
1468 |
+
"output_type": "stream",
|
1469 |
+
"text": [
|
1470 |
+
"pred shape torch.Size([4096, 20]), input shape torch.Size([4096, 26]), target shape torch.Size([4096, 20])\n"
|
1471 |
+
]
|
1472 |
+
}
|
1473 |
+
],
|
1474 |
+
"source": [
|
1475 |
+
"# save the predictions\n",
|
1476 |
+
"dataframe = translate_prediction(ipLang.index2char, testData[:][0], opLang.index2char, pred, testData[:][1])\n",
|
1477 |
+
"dataframe.to_csv(\"predictions.csv\")"
|
1478 |
+
]
|
1479 |
+
},
|
1480 |
+
{
|
1481 |
+
"cell_type": "code",
|
1482 |
+
"execution_count": null,
|
1483 |
+
"metadata": {
|
1484 |
+
"id": "PKMYPZdtdbDh"
|
1485 |
+
},
|
1486 |
+
"outputs": [],
|
1487 |
+
"source": [
|
1488 |
+
"import pandas as pd\n",
|
1489 |
+
"data = pd.read_csv(\"predictions.csv\")"
|
1490 |
+
]
|
1491 |
+
},
|
1492 |
+
{
|
1493 |
+
"cell_type": "code",
|
1494 |
+
"execution_count": null,
|
1495 |
+
"metadata": {
|
1496 |
+
"colab": {
|
1497 |
+
"base_uri": "https://localhost:8080/",
|
1498 |
+
"height": 142
|
1499 |
+
},
|
1500 |
+
"id": "8gCL1rXCdgYp",
|
1501 |
+
"outputId": "d64b794c-d173-4871-80fc-93b8211ebedc"
|
1502 |
+
},
|
1503 |
+
"outputs": [],
|
1504 |
+
"source": [
|
1505 |
+
"# We also want to plot the prdiction table to wandb\n",
|
1506 |
+
"wandb.init(project=\"CS6910_Assignment_3\")"
|
1507 |
+
]
|
1508 |
+
},
|
1509 |
+
{
|
1510 |
+
"cell_type": "code",
|
1511 |
+
"execution_count": null,
|
1512 |
+
"metadata": {
|
1513 |
+
"id": "N1r2ownhdjbz"
|
1514 |
+
},
|
1515 |
+
"outputs": [],
|
1516 |
+
"source": [
|
1517 |
+
"table = wandb.Table(dataframe=data)\n",
|
1518 |
+
"wandb.log({\"data\": table})"
|
1519 |
+
]
|
1520 |
+
},
|
1521 |
+
{
|
1522 |
+
"cell_type": "markdown",
|
1523 |
+
"metadata": {
|
1524 |
+
"id": "LDP4KvWdFnIL"
|
1525 |
+
},
|
1526 |
+
"source": [
|
1527 |
+
"# Plotting the Attention HeatMaps"
|
1528 |
+
]
|
1529 |
+
},
|
1530 |
+
{
|
1531 |
+
"cell_type": "code",
|
1532 |
+
"execution_count": null,
|
1533 |
+
"metadata": {
|
1534 |
+
"colab": {
|
1535 |
+
"base_uri": "https://localhost:8080/",
|
1536 |
+
"height": 1000,
|
1537 |
+
"referenced_widgets": [
|
1538 |
+
"1b0c5a6e21a349cba57322f850ad9f48",
|
1539 |
+
"3aa935a6db14483d8aaada58a84a3e47",
|
1540 |
+
"eabcea7a8bbf42f6aaa3995c0dece721",
|
1541 |
+
"b3b7711edb5542e08c53c4f37da10203",
|
1542 |
+
"39a8a3a9b6f1495ea17fd1b3d86b67c0",
|
1543 |
+
"18a8e2e817b947f9aad87b1ccaf96ea6",
|
1544 |
+
"da62d6e5ad0a462b98e1591d39038e1e",
|
1545 |
+
"9b5bb4f7f4a846c28ab967b64107726e"
|
1546 |
+
]
|
1547 |
+
},
|
1548 |
+
"id": "4WfJEdcgFmiI",
|
1549 |
+
"outputId": "ff266529-4345-4cdc-9860-11914b099052"
|
1550 |
+
},
|
1551 |
+
"outputs": [],
|
1552 |
+
"source": [
|
1553 |
+
"import matplotlib.pyplot as plt\n",
|
1554 |
+
"import numpy as np\n",
|
1555 |
+
"from matplotlib.font_manager import FontProperties\n",
|
1556 |
+
"tel_font = FontProperties(fname = 'TiroDevanagariHindi-Regular.ttf')\n",
|
1557 |
+
"# Assuming you have attention_weights of shape (batch_size, output_sequence_length, batch_size, input_sequence_length)\n",
|
1558 |
+
"# and prediction_matrix of shape (batch_size, output_sequence_length)\n",
|
1559 |
+
"# and input_matrix of shape (batch_size, input_sequence_length)\n",
|
1560 |
+
"\n",
|
1561 |
+
"# Define the grid dimensions\n",
|
1562 |
+
"rows = int(np.ceil(np.sqrt(12)))\n",
|
1563 |
+
"cols = int(np.ceil(12 / rows))\n",
|
1564 |
+
"\n",
|
1565 |
+
"# Create a figure and subplots\n",
|
1566 |
+
"fig, axes = plt.subplots(rows, cols, figsize=(9, 9))\n",
|
1567 |
+
"\n",
|
1568 |
+
"for i, ax in enumerate(axes.flatten()):\n",
|
1569 |
+
" if i < 12:\n",
|
1570 |
+
" prediction = [opLang.index2char[j.item()] for j in pred[i+1]]\n",
|
1571 |
+
" \n",
|
1572 |
+
" pred_word=\"\"\n",
|
1573 |
+
" input_word=\"\"\n",
|
1574 |
+
"\n",
|
1575 |
+
" for j in range(len(prediction)):\n",
|
1576 |
+
" # Ignore padding\n",
|
1577 |
+
" if(prediction[j] != '#'):\n",
|
1578 |
+
" pred_word += prediction[j]\n",
|
1579 |
+
" else : \n",
|
1580 |
+
" break\n",
|
1581 |
+
" input_seq = [ipLang.index2char[j.item()] for j in testData[i][0]]\n",
|
1582 |
+
" \n",
|
1583 |
+
" for j in range(len(input_seq)):\n",
|
1584 |
+
" if(input_seq[j] != '#'):\n",
|
1585 |
+
" input_word += input_seq[j]\n",
|
1586 |
+
" else : \n",
|
1587 |
+
" break\n",
|
1588 |
+
" attn_weights = atten_weights[i, :len(pred_word), :len(input_word)].detach().cpu().numpy()\n",
|
1589 |
+
" ax.imshow(attn_weights.T, cmap='hot', interpolation='nearest')\n",
|
1590 |
+
" ax.xaxis.set_label_position('top')\n",
|
1591 |
+
" ax.set_title(f'Example {i+1}')\n",
|
1592 |
+
" ax.set_xlabel('Output predicted')\n",
|
1593 |
+
" ax.set_ylabel('Input word')\n",
|
1594 |
+
" ax.set_xticks(np.arange(len(pred_word)))\n",
|
1595 |
+
" ax.set_xticklabels(pred_word, rotation = 90, fontproperties = tel_font,fontdict={'fontsize':8})\n",
|
1596 |
+
" ax.xaxis.tick_top()\n",
|
1597 |
+
"\n",
|
1598 |
+
" ax.set_yticks(np.arange(len(input_word)))\n",
|
1599 |
+
" ax.set_yticklabels(input_word, rotation=90)\n",
|
1600 |
+
" \n",
|
1601 |
+
" \n",
|
1602 |
+
"\n",
|
1603 |
+
"# Adjust the spacing between subplots\n",
|
1604 |
+
"plt.tight_layout()\n",
|
1605 |
+
"\n",
|
1606 |
+
"# Show the plot\n",
|
1607 |
+
"plt.show()\n",
|
1608 |
+
"wandb.init(project='CS6910_Assignment_3')\n",
|
1609 |
+
"\n",
|
1610 |
+
"# Convert the matplotlib figure to an image\n",
|
1611 |
+
"fig.canvas.draw()\n",
|
1612 |
+
"image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')\n",
|
1613 |
+
"image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
|
1614 |
+
"\n",
|
1615 |
+
"# Log the image in wandb\n",
|
1616 |
+
"wandb.log({\"attention_heatmaps\": [wandb.Image(image)]})"
|
1617 |
+
]
|
1618 |
+
},
|
1619 |
+
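{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check for the heatmaps (illustrative): the attention weights are a\n",
"# softmax over input positions, so every row of a heatmap should sum to ~1.\n",
"print(atten_weights[0].sum(dim=-1))  # expect a vector of values close to 1\n"
]
},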
{
|
1620 |
+
"cell_type": "code",
|
1621 |
+
"execution_count": null,
|
1622 |
+
"metadata": {
|
1623 |
+
"id": "FnHR_oql6-S4"
|
1624 |
+
},
|
1625 |
+
"outputs": [],
|
1626 |
+
"source": []
|
1627 |
+
}
|
1628 |
+
],
|
1629 |
+
"metadata": {
|
1630 |
+
"accelerator": "GPU",
|
1631 |
+
"colab": {
|
1632 |
+
"collapsed_sections": [
|
1633 |
+
"hRdpoWePeYHn",
|
1634 |
+
"44xIRolL_T_d",
|
1635 |
+
"Q1TioafYgICa",
|
1636 |
+
"svxssm9Havhb",
|
1637 |
+
"J56aq1J6a07q",
|
1638 |
+
"5JcQdylzI_Fc",
|
1639 |
+
"658W9RARGEUf",
|
1640 |
+
"q7fAgs5uQni_",
|
1641 |
+
"n4rGh7vuQqaa",
|
1642 |
+
"0SsnRWlgQmCI",
|
1643 |
+
"nvyRJWUUbR2f",
|
1644 |
+
"8ETW0BG_Pa24",
|
1645 |
+
"MQPGy32rnD3V",
|
1646 |
+
"z_aYZvDD1OHU",
|
1647 |
+
"pKvBd5mKf0Hf",
|
1648 |
+
"FYMa5jTQRUaB",
|
1649 |
+
"zfuv5FoA1wt2",
|
1650 |
+
"W7CYNChRGuGK"
|
1651 |
+
],
|
1652 |
+
"gpuType": "T4",
|
1653 |
+
"include_colab_link": true,
|
1654 |
+
"provenance": []
|
1655 |
+
},
|
1656 |
+
"gpuClass": "standard",
|
1657 |
+
"kernelspec": {
|
1658 |
+
"display_name": "Python 3",
|
1659 |
+
"name": "python3"
|
1660 |
+
},
|
1661 |
+
"language_info": {
|
1662 |
+
"name": "python"
|
1663 |
+
},
|
1664 |
+
"widgets": {
|
1665 |
+
"application/vnd.jupyter.widget-state+json": {
|
1666 |
+
"18a8e2e817b947f9aad87b1ccaf96ea6": {
|
1667 |
+
"model_module": "@jupyter-widgets/controls",
|
1668 |
+
"model_module_version": "1.5.0",
|
1669 |
+
"model_name": "DescriptionStyleModel",
|
1670 |
+
"state": {
|
1671 |
+
"_model_module": "@jupyter-widgets/controls",
|
1672 |
+
"_model_module_version": "1.5.0",
|
1673 |
+
"_model_name": "DescriptionStyleModel",
|
1674 |
+
"_view_count": null,
|
1675 |
+
"_view_module": "@jupyter-widgets/base",
|
1676 |
+
"_view_module_version": "1.2.0",
|
1677 |
+
"_view_name": "StyleView",
|
1678 |
+
"description_width": ""
|
1679 |
+
}
|
1680 |
+
},
|
1681 |
+
"1b0c5a6e21a349cba57322f850ad9f48": {
|
1682 |
+
"model_module": "@jupyter-widgets/controls",
|
1683 |
+
"model_module_version": "1.5.0",
|
1684 |
+
"model_name": "VBoxModel",
|
1685 |
+
"state": {
|
1686 |
+
"_dom_classes": [],
|
1687 |
+
"_model_module": "@jupyter-widgets/controls",
|
1688 |
+
"_model_module_version": "1.5.0",
|
1689 |
+
"_model_name": "VBoxModel",
|
1690 |
+
"_view_count": null,
|
1691 |
+
"_view_module": "@jupyter-widgets/controls",
|
1692 |
+
"_view_module_version": "1.5.0",
|
1693 |
+
"_view_name": "VBoxView",
|
1694 |
+
"box_style": "",
|
1695 |
+
"children": [
|
1696 |
+
"IPY_MODEL_3aa935a6db14483d8aaada58a84a3e47",
|
1697 |
+
"IPY_MODEL_eabcea7a8bbf42f6aaa3995c0dece721"
|
1698 |
+
],
|
1699 |
+
"layout": "IPY_MODEL_b3b7711edb5542e08c53c4f37da10203"
|
1700 |
+
}
|
1701 |
+
},
|
1702 |
+
"39a8a3a9b6f1495ea17fd1b3d86b67c0": {
|
1703 |
+
"model_module": "@jupyter-widgets/base",
|
1704 |
+
"model_module_version": "1.2.0",
|
1705 |
+
"model_name": "LayoutModel",
|
1706 |
+
"state": {
|
1707 |
+
"_model_module": "@jupyter-widgets/base",
|
1708 |
+
"_model_module_version": "1.2.0",
|
1709 |
+
"_model_name": "LayoutModel",
|
1710 |
+
"_view_count": null,
|
1711 |
+
"_view_module": "@jupyter-widgets/base",
|
1712 |
+
"_view_module_version": "1.2.0",
|
1713 |
+
"_view_name": "LayoutView",
|
1714 |
+
"align_content": null,
|
1715 |
+
"align_items": null,
|
1716 |
+
"align_self": null,
|
1717 |
+
"border": null,
|
1718 |
+
"bottom": null,
|
1719 |
+
"display": null,
|
1720 |
+
"flex": null,
|
1721 |
+
"flex_flow": null,
|
1722 |
+
"grid_area": null,
|
1723 |
+
"grid_auto_columns": null,
|
1724 |
+
"grid_auto_flow": null,
|
1725 |
+
"grid_auto_rows": null,
|
1726 |
+
"grid_column": null,
|
1727 |
+
"grid_gap": null,
|
1728 |
+
"grid_row": null,
|
1729 |
+
"grid_template_areas": null,
|
1730 |
+
"grid_template_columns": null,
|
1731 |
+
"grid_template_rows": null,
|
1732 |
+
"height": null,
|
1733 |
+
"justify_content": null,
|
1734 |
+
"justify_items": null,
|
1735 |
+
"left": null,
|
1736 |
+
"margin": null,
|
1737 |
+
"max_height": null,
|
1738 |
+
"max_width": null,
|
1739 |
+
"min_height": null,
|
1740 |
+
"min_width": null,
|
1741 |
+
"object_fit": null,
|
1742 |
+
"object_position": null,
|
1743 |
+
"order": null,
|
1744 |
+
"overflow": null,
|
1745 |
+
"overflow_x": null,
|
1746 |
+
"overflow_y": null,
|
1747 |
+
"padding": null,
|
1748 |
+
"right": null,
|
1749 |
+
"top": null,
|
1750 |
+
"visibility": null,
|
1751 |
+
"width": null
|
1752 |
+
}
|
1753 |
+
},
|
1754 |
+
"3aa935a6db14483d8aaada58a84a3e47": {
|
1755 |
+
"model_module": "@jupyter-widgets/controls",
|
1756 |
+
"model_module_version": "1.5.0",
|
1757 |
+
"model_name": "LabelModel",
|
1758 |
+
"state": {
|
1759 |
+
"_dom_classes": [],
|
1760 |
+
"_model_module": "@jupyter-widgets/controls",
|
1761 |
+
"_model_module_version": "1.5.0",
|
1762 |
+
"_model_name": "LabelModel",
|
1763 |
+
"_view_count": null,
|
1764 |
+
"_view_module": "@jupyter-widgets/controls",
|
1765 |
+
"_view_module_version": "1.5.0",
|
1766 |
+
"_view_name": "LabelView",
|
1767 |
+
"description": "",
|
1768 |
+
"description_tooltip": null,
|
1769 |
+
"layout": "IPY_MODEL_39a8a3a9b6f1495ea17fd1b3d86b67c0",
|
1770 |
+
"placeholder": "",
|
1771 |
+
"style": "IPY_MODEL_18a8e2e817b947f9aad87b1ccaf96ea6",
|
1772 |
+
"value": "0.071 MB of 0.071 MB uploaded (0.000 MB deduped)\r"
|
1773 |
+
}
|
1774 |
+
},
|
1775 |
+
"9b5bb4f7f4a846c28ab967b64107726e": {
|
1776 |
+
"model_module": "@jupyter-widgets/controls",
|
1777 |
+
"model_module_version": "1.5.0",
|
1778 |
+
"model_name": "ProgressStyleModel",
|
1779 |
+
"state": {
|
1780 |
+
"_model_module": "@jupyter-widgets/controls",
|
1781 |
+
"_model_module_version": "1.5.0",
|
1782 |
+
"_model_name": "ProgressStyleModel",
|
1783 |
+
"_view_count": null,
|
1784 |
+
"_view_module": "@jupyter-widgets/base",
|
1785 |
+
"_view_module_version": "1.2.0",
|
1786 |
+
"_view_name": "StyleView",
|
1787 |
+
"bar_color": null,
|
1788 |
+
"description_width": ""
|
1789 |
+
}
|
1790 |
+
},
|
1791 |
+
"b3b7711edb5542e08c53c4f37da10203": {
|
1792 |
+
"model_module": "@jupyter-widgets/base",
|
1793 |
+
"model_module_version": "1.2.0",
|
1794 |
+
"model_name": "LayoutModel",
|
1795 |
+
"state": {
|
1796 |
+
"_model_module": "@jupyter-widgets/base",
|
1797 |
+
"_model_module_version": "1.2.0",
|
1798 |
+
"_model_name": "LayoutModel",
|
1799 |
+
"_view_count": null,
|
1800 |
+
"_view_module": "@jupyter-widgets/base",
|
1801 |
+
"_view_module_version": "1.2.0",
|
1802 |
+
"_view_name": "LayoutView",
|
1803 |
+
"align_content": null,
|
1804 |
+
"align_items": null,
|
1805 |
+
"align_self": null,
|
1806 |
+
"border": null,
|
1807 |
+
"bottom": null,
|
1808 |
+
"display": null,
|
1809 |
+
"flex": null,
|
1810 |
+
"flex_flow": null,
|
1811 |
+
"grid_area": null,
|
1812 |
+
"grid_auto_columns": null,
|
1813 |
+
"grid_auto_flow": null,
|
1814 |
+
"grid_auto_rows": null,
|
1815 |
+
"grid_column": null,
|
1816 |
+
"grid_gap": null,
|
1817 |
+
"grid_row": null,
|
1818 |
+
"grid_template_areas": null,
|
1819 |
+
"grid_template_columns": null,
|
1820 |
+
"grid_template_rows": null,
|
1821 |
+
"height": null,
|
1822 |
+
"justify_content": null,
|
1823 |
+
"justify_items": null,
|
1824 |
+
"left": null,
|
1825 |
+
"margin": null,
|
1826 |
+
"max_height": null,
|
1827 |
+
"max_width": null,
|
1828 |
+
"min_height": null,
|
1829 |
+
"min_width": null,
|
1830 |
+
"object_fit": null,
|
1831 |
+
"object_position": null,
|
1832 |
+
"order": null,
|
1833 |
+
"overflow": null,
|
1834 |
+
"overflow_x": null,
|
1835 |
+
"overflow_y": null,
|
1836 |
+
"padding": null,
|
1837 |
+
"right": null,
|
1838 |
+
"top": null,
|
1839 |
+
"visibility": null,
|
1840 |
+
"width": null
|
1841 |
+
}
|
1842 |
+
},
|
1843 |
+
"da62d6e5ad0a462b98e1591d39038e1e": {
|
1844 |
+
"model_module": "@jupyter-widgets/base",
|
1845 |
+
"model_module_version": "1.2.0",
|
1846 |
+
"model_name": "LayoutModel",
|
1847 |
+
"state": {
|
1848 |
+
"_model_module": "@jupyter-widgets/base",
|
1849 |
+
"_model_module_version": "1.2.0",
|
1850 |
+
"_model_name": "LayoutModel",
|
1851 |
+
"_view_count": null,
|
1852 |
+
"_view_module": "@jupyter-widgets/base",
|
1853 |
+
"_view_module_version": "1.2.0",
|
1854 |
+
"_view_name": "LayoutView",
|
1855 |
+
"align_content": null,
|
1856 |
+
"align_items": null,
|
1857 |
+
"align_self": null,
|
1858 |
+
"border": null,
|
1859 |
+
"bottom": null,
|
1860 |
+
"display": null,
|
1861 |
+
"flex": null,
|
1862 |
+
"flex_flow": null,
|
1863 |
+
"grid_area": null,
|
1864 |
+
"grid_auto_columns": null,
|
1865 |
+
"grid_auto_flow": null,
|
1866 |
+
"grid_auto_rows": null,
|
1867 |
+
"grid_column": null,
|
1868 |
+
"grid_gap": null,
|
1869 |
+
"grid_row": null,
|
1870 |
+
"grid_template_areas": null,
|
1871 |
+
"grid_template_columns": null,
|
1872 |
+
"grid_template_rows": null,
|
1873 |
+
"height": null,
|
1874 |
+
"justify_content": null,
|
1875 |
+
"justify_items": null,
|
1876 |
+
"left": null,
|
1877 |
+
"margin": null,
|
1878 |
+
"max_height": null,
|
1879 |
+
"max_width": null,
|
1880 |
+
"min_height": null,
|
1881 |
+
"min_width": null,
|
1882 |
+
"object_fit": null,
|
1883 |
+
"object_position": null,
|
1884 |
+
"order": null,
|
1885 |
+
"overflow": null,
|
1886 |
+
"overflow_x": null,
|
1887 |
+
"overflow_y": null,
|
1888 |
+
"padding": null,
|
1889 |
+
"right": null,
|
1890 |
+
"top": null,
|
1891 |
+
"visibility": null,
|
1892 |
+
"width": null
|
1893 |
+
}
|
1894 |
+
},
|
1895 |
+
"eabcea7a8bbf42f6aaa3995c0dece721": {
|
1896 |
+
"model_module": "@jupyter-widgets/controls",
|
1897 |
+
"model_module_version": "1.5.0",
|
1898 |
+
"model_name": "FloatProgressModel",
|
1899 |
+
"state": {
|
1900 |
+
"_dom_classes": [],
|
1901 |
+
"_model_module": "@jupyter-widgets/controls",
|
1902 |
+
"_model_module_version": "1.5.0",
|
1903 |
+
"_model_name": "FloatProgressModel",
|
1904 |
+
"_view_count": null,
|
1905 |
+
"_view_module": "@jupyter-widgets/controls",
|
1906 |
+
"_view_module_version": "1.5.0",
|
1907 |
+
"_view_name": "ProgressView",
|
1908 |
+
"bar_style": "",
|
1909 |
+
"description": "",
|
1910 |
+
"description_tooltip": null,
|
1911 |
+
"layout": "IPY_MODEL_da62d6e5ad0a462b98e1591d39038e1e",
|
1912 |
+
"max": 1,
|
1913 |
+
"min": 0,
|
1914 |
+
"orientation": "horizontal",
|
1915 |
+
"style": "IPY_MODEL_9b5bb4f7f4a846c28ab967b64107726e",
|
1916 |
+
"value": 1
|
1917 |
+
}
|
1918 |
+
}
|
1919 |
+
}
|
1920 |
+
}
|
1921 |
+
},
|
1922 |
+
"nbformat": 4,
|
1923 |
+
"nbformat_minor": 0
|
1924 |
+
}
|
notebooks/transformers.ipynb
ADDED
@@ -0,0 +1,1929 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/pankajrawat9075/Language-Transliteration-Model/blob/main/transformers_encoder_decoder.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "hRdpoWePeYHn"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"## Importing Libraries and models"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": null,
|
25 |
+
"metadata": {
|
26 |
+
"execution": {
|
27 |
+
"iopub.execute_input": "2024-04-06T12:27:53.981869Z",
|
28 |
+
"iopub.status.busy": "2024-04-06T12:27:53.981590Z",
|
29 |
+
"iopub.status.idle": "2024-04-06T12:28:06.958537Z",
|
30 |
+
"shell.execute_reply": "2024-04-06T12:28:06.957350Z",
|
31 |
+
"shell.execute_reply.started": "2024-04-06T12:27:53.981844Z"
|
32 |
+
},
|
33 |
+
"id": "0LBvFtYGCNgJ",
|
34 |
+
"trusted": true
|
35 |
+
},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"%%capture\n",
|
39 |
+
"!pip install wandb"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": null,
|
45 |
+
"metadata": {
|
46 |
+
"execution": {
|
47 |
+
"iopub.execute_input": "2024-04-06T12:28:06.960754Z",
|
48 |
+
"iopub.status.busy": "2024-04-06T12:28:06.960461Z",
|
49 |
+
"iopub.status.idle": "2024-04-06T12:28:12.559713Z",
|
50 |
+
"shell.execute_reply": "2024-04-06T12:28:12.558903Z",
|
51 |
+
"shell.execute_reply.started": "2024-04-06T12:28:06.960728Z"
|
52 |
+
},
|
53 |
+
"id": "z4ZVrIumZcDt",
|
54 |
+
"trusted": true
|
55 |
+
},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"from __future__ import unicode_literals, print_function, division\n",
|
59 |
+
"from io import open\n",
|
60 |
+
"import unicodedata\n",
|
61 |
+
"import string\n",
|
62 |
+
"import re\n",
|
63 |
+
"import wandb\n",
|
64 |
+
"import random\n",
|
65 |
+
"import pandas as pd\n",
|
66 |
+
"import torch\n",
|
67 |
+
"import time\n",
|
68 |
+
"import numpy as np\n",
|
69 |
+
"import torch.nn as nn\n",
|
70 |
+
"from torch import optim\n",
|
71 |
+
"import matplotlib.pyplot as plt\n",
|
72 |
+
"import torch.nn.functional as F\n",
|
73 |
+
"from torch.utils.data import TensorDataset, DataLoader\n",
|
74 |
+
"\n",
|
75 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
76 |
+
"torch.cuda.empty_cache()"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "code",
|
81 |
+
"execution_count": null,
|
82 |
+
"metadata": {
|
83 |
+
"colab": {
|
84 |
+
"base_uri": "https://localhost:8080/"
|
85 |
+
},
|
86 |
+
"execution": {
|
87 |
+
"iopub.execute_input": "2024-04-06T12:28:12.561336Z",
|
88 |
+
"iopub.status.busy": "2024-04-06T12:28:12.560805Z",
|
89 |
+
"iopub.status.idle": "2024-04-06T12:28:12.571498Z",
|
90 |
+
"shell.execute_reply": "2024-04-06T12:28:12.570579Z",
|
91 |
+
"shell.execute_reply.started": "2024-04-06T12:28:12.561311Z"
|
92 |
+
},
|
93 |
+
"id": "qwL09v65CIse",
|
94 |
+
"outputId": "5ea72523-6a50-474c-b617-b77e16d72ef3",
|
95 |
+
"trusted": true
|
96 |
+
},
|
97 |
+
"outputs": [
|
98 |
+
{
|
99 |
+
"name": "stdout",
|
100 |
+
"output_type": "stream",
|
101 |
+
"text": [
|
102 |
+
"cuda\n"
|
103 |
+
]
|
104 |
+
}
|
105 |
+
],
|
106 |
+
"source": [
|
107 |
+
"print(device)"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"cell_type": "markdown",
|
112 |
+
"metadata": {
|
113 |
+
"id": "44xIRolL_T_d"
|
114 |
+
},
|
115 |
+
"source": [
|
116 |
+
"## Load Dataset"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"cell_type": "code",
|
121 |
+
"execution_count": null,
|
122 |
+
"metadata": {
|
123 |
+
"execution": {
|
124 |
+
"iopub.execute_input": "2024-04-06T12:28:12.573774Z",
|
125 |
+
"iopub.status.busy": "2024-04-06T12:28:12.573504Z",
|
126 |
+
"iopub.status.idle": "2024-04-06T12:28:12.583678Z",
|
127 |
+
"shell.execute_reply": "2024-04-06T12:28:12.582875Z",
|
128 |
+
"shell.execute_reply.started": "2024-04-06T12:28:12.573751Z"
|
129 |
+
},
|
130 |
+
"id": "Y4zemXiyE6Fi",
|
131 |
+
"trusted": true
|
132 |
+
},
|
133 |
+
"outputs": [],
|
134 |
+
"source": [
|
135 |
+
"class Language:\n",
|
136 |
+
" def __init__(self, name):\n",
|
137 |
+
" self.name = name\n",
|
138 |
+
" self.char2index = {'#': 0, '$': 1, '^': 2} # '^': start of sequence, '$' : unknown char, '#' : padding\n",
|
139 |
+
" self.index2char = {0: '#', 1: '$', 2: '^'}\n",
|
140 |
+
" self.vocab_size = 3 # Count\n",
|
141 |
+
"\n",
|
142 |
+
" def addWord(self, word):\n",
|
143 |
+
" for char in word:\n",
|
144 |
+
" self.addChar(char)\n",
|
145 |
+
"\n",
|
146 |
+
" def addChar(self, char):\n",
|
147 |
+
" if char not in self.char2index:\n",
|
148 |
+
" self.char2index[char] = self.vocab_size\n",
|
149 |
+
" self.index2char[self.vocab_size] = char\n",
|
150 |
+
" self.vocab_size += 1\n",
|
151 |
+
"\n",
|
152 |
+
" def encode(self, s):\n",
|
153 |
+
" return [self.char2index[ch] for ch in s]\n",
|
154 |
+
"\n",
|
155 |
+
" def decode(self, l):\n",
|
156 |
+
" return ''.join([self.index2char[i] for i in l])\n",
|
157 |
+
"\n",
|
158 |
+
" def vocab(self):\n",
|
159 |
+
" return self.char2index.keys()\n"
|
160 |
+
]
|
161 |
+
},
|
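The Language class above builds a character-level vocabulary around three reserved symbols ('#' padding, '$' unknown, '^' start of sequence). A minimal round-trip sketch (the word "namaste" is an arbitrary example, not drawn from the dataset):

lang = Language('eng')
lang.addWord('namaste')            # registers n, a, m, s, t, e after the 3 specials
ids = lang.encode('namaste')       # [3, 4, 5, 4, 6, 7, 8]
assert lang.decode(ids) == 'namaste'
print(lang.vocab_size)             # 9 = 3 reserved symbols + 6 distinct letters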
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"metadata": {
|
166 |
+
"execution": {
|
167 |
+
"iopub.execute_input": "2024-04-06T12:28:12.584802Z",
|
168 |
+
"iopub.status.busy": "2024-04-06T12:28:12.584565Z",
|
169 |
+
"iopub.status.idle": "2024-04-06T12:28:12.594791Z",
|
170 |
+
"shell.execute_reply": "2024-04-06T12:28:12.593973Z",
|
171 |
+
"shell.execute_reply.started": "2024-04-06T12:28:12.584781Z"
|
172 |
+
},
|
173 |
+
"id": "IDGaCO8DkYpc",
|
174 |
+
"trusted": true
|
175 |
+
},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"input_shape = 0\n",
|
179 |
+
"def preprocess(data, input_lang, output_lang, s=''):\n",
|
180 |
+
"\n",
|
181 |
+
" unknown = input_lang.char2index['$']\n",
|
182 |
+
"\n",
|
183 |
+
" input_max_len = 27\n",
|
184 |
+
" output_max_len = max([len(o) for o in data[1]])\n",
|
185 |
+
"\n",
|
186 |
+
" n = len(data)\n",
|
187 |
+
" input = torch.zeros((n, input_max_len + 1), device = device)\n",
|
188 |
+
" output = torch.zeros((n, output_max_len + 2), device = device)\n",
|
189 |
+
"\n",
|
190 |
+
" for i in range(n):\n",
|
191 |
+
"\n",
|
192 |
+
" inp = data[0][i].ljust(input_max_len + 1, '#')\n",
|
193 |
+
" op = '^' + data[1][i] # add start symbol to output\n",
|
194 |
+
" op = op.ljust(output_max_len + 2, '#')\n",
|
195 |
+
"\n",
|
196 |
+
" for index, char in enumerate(inp):\n",
|
197 |
+
" if char in input_lang.char2index:\n",
|
198 |
+
" input[i][index] = input_lang.char2index[char]\n",
|
199 |
+
" else:\n",
|
200 |
+
" input[i][index] = unknown\n",
|
201 |
+
"\n",
|
202 |
+
" for index, char in enumerate(op):\n",
|
203 |
+
" if char in output_lang.char2index:\n",
|
204 |
+
" output[i][index] = output_lang.char2index[char]\n",
|
205 |
+
" else:\n",
|
206 |
+
" output[i][index] = unknown\n",
|
207 |
+
"\n",
|
208 |
+
" print(s, ' dataset')\n",
|
209 |
+
" print(input.shape)\n",
|
210 |
+
" print(output.shape)\n",
|
211 |
+
"\n",
|
212 |
+
" return TensorDataset(input.to(torch.int32), output.to(torch.int32))"
|
213 |
+
]
|
214 |
+
},
|
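preprocess right-pads everything: each source word is widened with '#' to input_max_len + 1 = 28 characters, and each target gets the start symbol '^' prepended before being widened to output_max_len + 2. A sketch of the resulting string layout (hypothetical pair "ghar" / "घर", not taken from the CSVs):

inp = 'ghar'.ljust(28, '#')         # 'ghar' followed by 24 '#' padding chars
out = ('^' + 'घर').ljust(22, '#')   # start symbol, word, then '#' padding
print(len(inp), len(out))           # 28 22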
215 |
+
{
|
216 |
+
"cell_type": "code",
|
217 |
+
"execution_count": null,
|
218 |
+
"metadata": {
|
219 |
+
"colab": {
|
220 |
+
"base_uri": "https://localhost:8080/"
|
221 |
+
},
|
222 |
+
"execution": {
|
223 |
+
"iopub.execute_input": "2024-04-06T12:28:12.596018Z",
|
224 |
+
"iopub.status.busy": "2024-04-06T12:28:12.595741Z",
|
225 |
+
"iopub.status.idle": "2024-04-06T12:29:16.322883Z",
|
226 |
+
"shell.execute_reply": "2024-04-06T12:29:16.321877Z",
|
227 |
+
"shell.execute_reply.started": "2024-04-06T12:28:12.595995Z"
|
228 |
+
},
|
229 |
+
"id": "PdS5OXKxfdCX",
|
230 |
+
"outputId": "283fb51a-9a4a-4fc5-bad1-ea66373b29b4",
|
231 |
+
"trusted": true
|
232 |
+
},
|
233 |
+
"outputs": [
|
234 |
+
{
|
235 |
+
"name": "stdout",
|
236 |
+
"output_type": "stream",
|
237 |
+
"text": [
|
238 |
+
"train dataset\n",
|
239 |
+
"torch.Size([51200, 28])\n",
|
240 |
+
"torch.Size([51200, 22])\n",
|
241 |
+
"validation dataset\n",
|
242 |
+
"torch.Size([4096, 28])\n",
|
243 |
+
"torch.Size([4096, 22])\n",
|
244 |
+
"test dataset\n",
|
245 |
+
"torch.Size([4096, 28])\n",
|
246 |
+
"torch.Size([4096, 22])\n"
|
247 |
+
]
|
248 |
+
}
|
249 |
+
],
|
250 |
+
"source": [
|
251 |
+
"def load_prepare_data(lang):\n",
|
252 |
+
"\n",
|
253 |
+
" train_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_train.csv\", header = None)\n",
|
254 |
+
" val_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_valid.csv\", header = None)\n",
|
255 |
+
" test_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_test.csv\", header = None)\n",
|
256 |
+
"\n",
|
257 |
+
" input_lang = Language('eng')\n",
|
258 |
+
" output_lang = Language(lang)\n",
|
259 |
+
"\n",
|
260 |
+
" # create vocablury\n",
|
261 |
+
" for i in range(len(train_df)):\n",
|
262 |
+
" input_lang.addWord(train_df[0][i]) # 'eng'\n",
|
263 |
+
" output_lang.addWord(train_df[1][i]) # 'hin'\n",
|
264 |
+
"\n",
|
265 |
+
" # encode the datasets\n",
|
266 |
+
" train_data = preprocess(train_df, input_lang, output_lang, 'train')\n",
|
267 |
+
" val_data = preprocess(val_df, input_lang, output_lang, 'validation')\n",
|
268 |
+
" test_data = preprocess(test_df, input_lang, output_lang, 'test')\n",
|
269 |
+
"\n",
|
270 |
+
" return train_data, val_data, test_data, input_lang, output_lang\n",
|
271 |
+
"\n",
|
272 |
+
"\n",
|
273 |
+
"train_data, val_data, test_data, input_lang, output_lang = load_prepare_data('hin')\n"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"cell_type": "code",
|
278 |
+
"execution_count": null,
|
279 |
+
"metadata": {
|
280 |
+
"colab": {
|
281 |
+
"base_uri": "https://localhost:8080/"
|
282 |
+
},
|
283 |
+
"execution": {
|
284 |
+
"iopub.execute_input": "2024-04-06T12:29:16.324674Z",
|
285 |
+
"iopub.status.busy": "2024-04-06T12:29:16.324273Z",
|
286 |
+
"iopub.status.idle": "2024-04-06T12:29:16.334834Z",
|
287 |
+
"shell.execute_reply": "2024-04-06T12:29:16.333992Z",
|
288 |
+
"shell.execute_reply.started": "2024-04-06T12:29:16.324643Z"
|
289 |
+
},
|
290 |
+
"id": "nu-NTR6BDj8e",
|
291 |
+
"outputId": "bd3dba2a-092d-4846-a5fb-f703f119b56a",
|
292 |
+
"trusted": true
|
293 |
+
},
|
294 |
+
"outputs": [
|
295 |
+
{
|
296 |
+
"name": "stdout",
|
297 |
+
"output_type": "stream",
|
298 |
+
"text": [
|
299 |
+
"hankers#####################\n"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"data": {
|
304 |
+
"text/plain": [
|
305 |
+
"'^हैंकर्स##############'"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
"execution_count": 7,
|
309 |
+
"metadata": {},
|
310 |
+
"output_type": "execute_result"
|
311 |
+
}
|
312 |
+
],
|
313 |
+
"source": [
|
314 |
+
"print(input_lang.decode(train_data[23][0].tolist()))\n",
|
315 |
+
"output_lang.decode(train_data[23][1].tolist())"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
{
|
319 |
+
"cell_type": "code",
|
320 |
+
"execution_count": null,
|
321 |
+
"metadata": {
|
322 |
+
"colab": {
|
323 |
+
"base_uri": "https://localhost:8080/"
|
324 |
+
},
|
325 |
+
"execution": {
|
326 |
+
"iopub.execute_input": "2024-04-06T12:29:16.336734Z",
|
327 |
+
"iopub.status.busy": "2024-04-06T12:29:16.336128Z",
|
328 |
+
"iopub.status.idle": "2024-04-06T12:29:16.355166Z",
|
329 |
+
"shell.execute_reply": "2024-04-06T12:29:16.354327Z",
|
330 |
+
"shell.execute_reply.started": "2024-04-06T12:29:16.336702Z"
|
331 |
+
},
|
332 |
+
"id": "yJI8iU6dBSE0",
|
333 |
+
"outputId": "818815ee-503e-4dcd-b7a6-5f00a06b5ace",
|
334 |
+
"trusted": true
|
335 |
+
},
|
336 |
+
"outputs": [
|
337 |
+
{
|
338 |
+
"data": {
|
339 |
+
"text/plain": [
|
340 |
+
"tensor([ 2, 34, 36, 17, 15, 7, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
341 |
+
" 0, 0, 0, 0], device='cuda:0', dtype=torch.int32)"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
"execution_count": 8,
|
345 |
+
"metadata": {},
|
346 |
+
"output_type": "execute_result"
|
347 |
+
}
|
348 |
+
],
|
349 |
+
"source": [
|
350 |
+
"train_data[23][1]"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"execution_count": null,
|
356 |
+
"metadata": {
|
357 |
+
"colab": {
|
358 |
+
"base_uri": "https://localhost:8080/"
|
359 |
+
},
|
360 |
+
"execution": {
|
361 |
+
"iopub.execute_input": "2024-04-06T12:29:16.356467Z",
|
362 |
+
"iopub.status.busy": "2024-04-06T12:29:16.356175Z",
|
363 |
+
"iopub.status.idle": "2024-04-06T12:29:19.315416Z",
|
364 |
+
"shell.execute_reply": "2024-04-06T12:29:19.314522Z",
|
365 |
+
"shell.execute_reply.started": "2024-04-06T12:29:16.356444Z"
|
366 |
+
},
|
367 |
+
"id": "SvmzS5Lt_Jnl",
|
368 |
+
"outputId": "1387d646-ea3c-4fbf-b44f-c071e2b07784",
|
369 |
+
"trusted": true
|
370 |
+
},
|
371 |
+
"outputs": [
|
372 |
+
{
|
373 |
+
"name": "stderr",
|
374 |
+
"output_type": "stream",
|
375 |
+
"text": [
|
376 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
|
377 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
|
378 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
|
379 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"data": {
|
384 |
+
"text/plain": [
|
385 |
+
"True"
|
386 |
+
]
|
387 |
+
},
|
388 |
+
"execution_count": 9,
|
389 |
+
"metadata": {},
|
390 |
+
"output_type": "execute_result"
|
391 |
+
}
|
392 |
+
],
|
393 |
+
"source": [
|
394 |
+
"wandb.login(key =\"\")"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "markdown",
|
399 |
+
"metadata": {
|
400 |
+
"id": "Q1TioafYgICa"
|
401 |
+
},
|
402 |
+
"source": [
|
403 |
+
"# seq2seq tranformer model"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "markdown",
|
408 |
+
"metadata": {
|
409 |
+
"id": "K94_u35dCk7-"
|
410 |
+
},
|
411 |
+
"source": [
|
412 |
+
"### hyperparameter settings"
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": null,
|
418 |
+
"metadata": {
|
419 |
+
"execution": {
|
420 |
+
"iopub.execute_input": "2024-04-06T12:29:19.318625Z",
|
421 |
+
"iopub.status.busy": "2024-04-06T12:29:19.318195Z",
|
422 |
+
"iopub.status.idle": "2024-04-06T12:29:19.324068Z",
|
423 |
+
"shell.execute_reply": "2024-04-06T12:29:19.323194Z",
|
424 |
+
"shell.execute_reply.started": "2024-04-06T12:29:19.318601Z"
|
425 |
+
},
|
426 |
+
"id": "PugX7KHvc65u",
|
427 |
+
"trusted": true
|
428 |
+
},
|
429 |
+
"outputs": [],
|
430 |
+
"source": [
|
431 |
+
"n_embd = 64\n",
|
432 |
+
"batch_size = 256\n",
|
433 |
+
"learning_rate = 1e-3\n",
|
434 |
+
"n_head = 4 # other options factors of 32 like 2, 8\n",
|
435 |
+
"n_layers = 6\n",
|
436 |
+
"dropout = 0.2\n",
|
437 |
+
"epochs = 50\n",
|
438 |
+
"\n",
|
439 |
+
"# encoder specific detail\n",
|
440 |
+
"input_vocab_size = input_lang.vocab_size\n",
|
441 |
+
"encoder_block_size = len(train_data[0][0])\n",
|
442 |
+
"\n",
|
443 |
+
"# decoder specific detail\n",
|
444 |
+
"output_vocab_size = output_lang.vocab_size\n",
|
445 |
+
"decoder_block_size = len(train_data[0][1])"
|
446 |
+
]
|
447 |
+
},
|
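Each attention head operates in a reduced subspace of the embedding; a quick sketch of the quantities these settings imply (assuming the dataset shapes printed earlier, i.e. encoder_block_size = 28 and decoder_block_size = 22):

d_k = n_embd // n_head   # per-head key/query/value width: 64 // 4 = 16
print(d_k, encoder_block_size, decoder_block_size)   # 16 28 22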
448 |
+
{
|
449 |
+
"cell_type": "markdown",
|
450 |
+
"metadata": {
|
451 |
+
"id": "XdltQ7oJCq1j"
|
452 |
+
},
|
453 |
+
"source": [
|
454 |
+
"### Encoder model"
|
455 |
+
]
|
456 |
+
},
|
457 |
+
{
|
458 |
+
"cell_type": "code",
|
459 |
+
"execution_count": null,
|
460 |
+
"metadata": {
|
461 |
+
"execution": {
|
462 |
+
"iopub.execute_input": "2024-04-06T12:29:19.325685Z",
|
463 |
+
"iopub.status.busy": "2024-04-06T12:29:19.325424Z",
|
464 |
+
"iopub.status.idle": "2024-04-06T12:29:19.351414Z",
|
465 |
+
"shell.execute_reply": "2024-04-06T12:29:19.350579Z",
|
466 |
+
"shell.execute_reply.started": "2024-04-06T12:29:19.325663Z"
|
467 |
+
},
|
468 |
+
"id": "uiluDiY7FAMU",
|
469 |
+
"trusted": true
|
470 |
+
},
|
471 |
+
"outputs": [],
|
472 |
+
"source": [
|
473 |
+
"class Head(nn.Module):\n",
|
474 |
+
" \"\"\" one self-attention head \"\"\"\n",
|
475 |
+
"\n",
|
476 |
+
" def __init__(self, n_embd, d_k, dropout, mask=0): # d_k is dimention of key , nomaly d_k = n_embd / 4\n",
|
477 |
+
" super().__init__()\n",
|
478 |
+
" self.mask = mask\n",
|
479 |
+
" self.key = nn.Linear(n_embd, d_k, bias=False, device=device)\n",
|
480 |
+
" self.query = nn.Linear(n_embd, d_k, bias=False, device=device)\n",
|
481 |
+
" self.value = nn.Linear(n_embd, d_k, bias=False, device=device)\n",
|
482 |
+
" if mask:\n",
|
483 |
+
" self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))\n",
|
484 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
485 |
+
"\n",
|
486 |
+
" def forward(self, x, encoder_output = None):\n",
|
487 |
+
" B,T,C = x.shape\n",
|
488 |
+
"\n",
|
489 |
+
" if encoder_output is not None:\n",
|
490 |
+
" k = self.key(encoder_output)\n",
|
491 |
+
" Be, Te, Ce = encoder_output.shape\n",
|
492 |
+
" else:\n",
|
493 |
+
" k = self.key(x) # (B,T,d_k)\n",
|
494 |
+
"\n",
|
495 |
+
" q = self.query(x) # (B,T,d_k)\n",
|
496 |
+
" # compute attention scores\n",
|
497 |
+
" wei = q @ k.transpose(-2, -1) * C**-0.5 # (B,T,T)\n",
|
498 |
+
"\n",
|
499 |
+
" if self.mask:\n",
|
500 |
+
" if encoder_output is not None:\n",
|
501 |
+
" wei = wei.masked_fill(self.tril[:T, :Te] == 0, float('-inf')) # (B,T,T)\n",
|
502 |
+
" else:\n",
|
503 |
+
" wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)\n",
|
504 |
+
"\n",
|
505 |
+
" wei = F.softmax(wei, dim=-1)\n",
|
506 |
+
" wei = self.dropout(wei)\n",
|
507 |
+
" # perform weighted aggregation of values\n",
|
508 |
+
" if encoder_output is not None:\n",
|
509 |
+
" v = self.value(encoder_output)\n",
|
510 |
+
" else:\n",
|
511 |
+
" v = self.value(x)\n",
|
512 |
+
" out = wei @ v # (B,T,C)\n",
|
513 |
+
" return out\n",
|
514 |
+
"\n",
|
515 |
+
"class MultiHeadAttention(nn.Module):\n",
|
516 |
+
" \"\"\" multiple self attention heads in parallel \"\"\"\n",
|
517 |
+
"\n",
|
518 |
+
" def __init__(self, n_embd, num_head, d_k, dropout, mask=0):\n",
|
519 |
+
" super().__init__()\n",
|
520 |
+
" self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])\n",
|
521 |
+
" self.proj = nn.Linear(n_embd, n_embd)\n",
|
522 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
523 |
+
"\n",
|
524 |
+
" def forward(self, x, encoder_output=None):\n",
|
525 |
+
" out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)\n",
|
526 |
+
" out = self.dropout(self.proj(out))\n",
|
527 |
+
" return out\n",
|
528 |
+
"\n",
|
529 |
+
"class FeedForward(nn.Module):\n",
|
530 |
+
" \"\"\" multiple self attention heads in parallel \"\"\"\n",
|
531 |
+
"\n",
|
532 |
+
" def __init__(self, n_embd, dropout):\n",
|
533 |
+
" super().__init__()\n",
|
534 |
+
" self.net = nn.Sequential(\n",
|
535 |
+
" nn.Linear(n_embd, 4 * n_embd),\n",
|
536 |
+
" nn.ReLU(),\n",
|
537 |
+
" nn.Linear(4 * n_embd, n_embd),\n",
|
538 |
+
" nn.Dropout(dropout)\n",
|
539 |
+
" )\n",
|
540 |
+
"\n",
|
541 |
+
" def forward(self, x):\n",
|
542 |
+
" return self.net(x)\n",
|
543 |
+
"\n",
|
544 |
+
"class encoderBlock(nn.Module):\n",
|
545 |
+
" \"\"\" Tranformer encoder block : communication followed by computation \"\"\"\n",
|
546 |
+
"\n",
|
547 |
+
" def __init__(self, n_embd, n_head, dropout):\n",
|
548 |
+
" super().__init__()\n",
|
549 |
+
" d_k = n_embd // n_head\n",
|
550 |
+
" self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout)\n",
|
551 |
+
" self.ffwd = FeedForward(n_embd, dropout)\n",
|
552 |
+
" self.ln1 = nn.LayerNorm(n_embd)\n",
|
553 |
+
" self.ln2 = nn.LayerNorm(n_embd)\n",
|
554 |
+
"\n",
|
555 |
+
" def forward(self, x, encoder_output=None):\n",
|
556 |
+
" x = x + self.sa(self.ln1(x), encoder_output)\n",
|
557 |
+
" x = x + self.ffwd(self.ln2(x))\n",
|
558 |
+
" return x\n",
|
559 |
+
"\n",
|
560 |
+
"class Encoder(nn.Module):\n",
|
561 |
+
"\n",
|
562 |
+
" def __init__(self, n_embd, n_head, n_layers, dropout):\n",
|
563 |
+
" super().__init__()\n",
|
564 |
+
"\n",
|
565 |
+
" self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd) # n_embd: input embedding dimension\n",
|
566 |
+
" self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)\n",
|
567 |
+
" self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])\n",
|
568 |
+
" self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
|
569 |
+
"\n",
|
570 |
+
" def forward(self, idx):\n",
|
571 |
+
" B, T = idx.shape\n",
|
572 |
+
" tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)\n",
|
573 |
+
" pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)\n",
|
574 |
+
" x = tok_emb + pos_emb # (B,T,n_embd)\n",
|
575 |
+
" x = self.blocks(x) # apply one attention layer (B,T,C)\n",
|
576 |
+
" x = self.ln_f(x) # (B,T,C)\n",
|
577 |
+
" return x\n"
|
578 |
+
]
|
579 |
+
},
|
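A shape sanity check for the Encoder (a sketch reusing the globals n_embd, n_head, n_layers, dropout, input_vocab_size, encoder_block_size and device from earlier cells; the batch size of 2 is arbitrary):

enc = Encoder(n_embd, n_head, n_layers, dropout).to(device)
dummy = torch.randint(0, input_vocab_size, (2, encoder_block_size), device=device)
enc_out = enc(dummy)                # runs all encoder blocks plus the final LayerNorm
print(enc_out.shape)                # torch.Size([2, 28, 64]) -> (B, T, n_embd)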
580 |
+
{
|
581 |
+
"cell_type": "markdown",
|
582 |
+
"metadata": {
|
583 |
+
"id": "GgPU486JC8Mz"
|
584 |
+
},
|
585 |
+
"source": [
|
586 |
+
"### Decoder model"
|
587 |
+
]
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"cell_type": "code",
|
591 |
+
"execution_count": null,
|
592 |
+
"metadata": {
|
593 |
+
"execution": {
|
594 |
+
"iopub.execute_input": "2024-04-06T12:29:19.352896Z",
|
595 |
+
"iopub.status.busy": "2024-04-06T12:29:19.352571Z",
|
596 |
+
"iopub.status.idle": "2024-04-06T12:29:19.367829Z",
|
597 |
+
"shell.execute_reply": "2024-04-06T12:29:19.366971Z",
|
598 |
+
"shell.execute_reply.started": "2024-04-06T12:29:19.352872Z"
|
599 |
+
},
|
600 |
+
"id": "JteOV0CdC_bv",
|
601 |
+
"trusted": true
|
602 |
+
},
|
603 |
+
"outputs": [],
|
604 |
+
"source": [
|
605 |
+
"class decoderBlock(nn.Module):\n",
|
606 |
+
" \"\"\" Tranformer decoder block : self communication then cross communication followed by computation \"\"\"\n",
|
607 |
+
"\n",
|
608 |
+
" def __init__(self, n_embd, n_head, dropout):\n",
|
609 |
+
" super().__init__()\n",
|
610 |
+
" d_k = n_embd // n_head\n",
|
611 |
+
" self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask = 1)\n",
|
612 |
+
" self.ca = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask = 1)\n",
|
613 |
+
" self.ffwd = FeedForward(n_embd, dropout)\n",
|
614 |
+
" self.ln1 = nn.LayerNorm(n_embd, device=device)\n",
|
615 |
+
" self.ln2 = nn.LayerNorm(n_embd, device=device)\n",
|
616 |
+
" self.ln3 = nn.LayerNorm(n_embd, device=device)\n",
|
617 |
+
"\n",
|
618 |
+
" def forward(self, x_encoder_output):\n",
|
619 |
+
" x = x_encoder_output[0]\n",
|
620 |
+
" encoder_output = x_encoder_output[1]\n",
|
621 |
+
" x = x + self.sa(self.ln1(x))\n",
|
622 |
+
" x = x + self.ca(self.ln2(x), encoder_output)\n",
|
623 |
+
" x = x + self.ffwd(self.ln3(x))\n",
|
624 |
+
" return (x,encoder_output)\n",
|
625 |
+
"\n",
|
626 |
+
"class Decoder(nn.Module):\n",
|
627 |
+
"\n",
|
628 |
+
" def __init__(self, n_embd, n_head, n_layers, dropout):\n",
|
629 |
+
" super().__init__()\n",
|
630 |
+
"\n",
|
631 |
+
" self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd) # n_embd: input embedding dimension\n",
|
632 |
+
" self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)\n",
|
633 |
+
" self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])\n",
|
634 |
+
" self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
|
635 |
+
" self.lm_head = nn.Linear(n_embd, output_vocab_size)\n",
|
636 |
+
"\n",
|
637 |
+
" def forward(self, idx, encoder_output, targets=None):\n",
|
638 |
+
" B, T = idx.shape\n",
|
639 |
+
"\n",
|
640 |
+
" tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)\n",
|
641 |
+
" pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)\n",
|
642 |
+
" x = tok_emb + pos_emb # (B,T,n_embd)\n",
|
643 |
+
"\n",
|
644 |
+
" x =self.blocks((x, encoder_output))\n",
|
645 |
+
" x = self.ln_f(x[0]) # (B,T,C)\n",
|
646 |
+
" logits = self.lm_head(x) # (B,T,output_vocab_size)\n",
|
647 |
+
"\n",
|
648 |
+
" if targets is None:\n",
|
649 |
+
" loss = None\n",
|
650 |
+
" else:\n",
|
651 |
+
" B, T, C = logits.shape\n",
|
652 |
+
" temp_logits = logits.view(B*T, C)\n",
|
653 |
+
" targets = targets.reshape(B*T)\n",
|
654 |
+
"\n",
|
655 |
+
" loss = F.cross_entropy(temp_logits, targets.long())\n",
|
656 |
+
"\n",
|
657 |
+
" # print(logits)\n",
|
658 |
+
" # out = torch.argmax(logits)\n",
|
659 |
+
"\n",
|
660 |
+
" return logits, loss\n",
|
661 |
+
"\n"
|
662 |
+
]
|
663 |
+
},
|
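A matching check for the Decoder, which is trained with teacher forcing: the target sequence goes in shifted right by one, and the same sequence shifted left by one serves as the prediction target (a sketch reusing enc_out from the encoder check plus the globals output_vocab_size and decoder_block_size):

dec = Decoder(n_embd, n_head, n_layers, dropout).to(device)
tgt = torch.randint(0, output_vocab_size, (2, decoder_block_size), device=device)
logits, loss = dec(tgt[:, :-1], enc_out, tgt[:, 1:])  # inputs vs. next-char targets
print(logits.shape, loss.item())    # torch.Size([2, 21, output_vocab_size]), scalar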
664 |
+
{
|
665 |
+
"cell_type": "markdown",
|
666 |
+
"metadata": {
|
667 |
+
"id": "EBjmsIcklM8Y"
|
668 |
+
},
|
669 |
+
"source": [
|
670 |
+
"# Training Time"
|
671 |
+
]
|
672 |
+
},
|
673 |
+
{
|
674 |
+
"cell_type": "markdown",
|
675 |
+
"metadata": {
|
676 |
+
"id": "lLfHEDk8FNfY"
|
677 |
+
},
|
678 |
+
"source": [
|
679 |
+
"## sweep config"
|
680 |
+
]
|
681 |
+
},
|
682 |
+
{
|
683 |
+
"cell_type": "code",
|
684 |
+
"execution_count": null,
|
685 |
+
"metadata": {
|
686 |
+
"execution": {
|
687 |
+
"iopub.execute_input": "2024-04-04T14:54:15.308213Z",
|
688 |
+
"iopub.status.busy": "2024-04-04T14:54:15.307981Z",
|
689 |
+
"iopub.status.idle": "2024-04-04T14:54:15.319933Z",
|
690 |
+
"shell.execute_reply": "2024-04-04T14:54:15.319070Z",
|
691 |
+
"shell.execute_reply.started": "2024-04-04T14:54:15.308192Z"
|
692 |
+
},
|
693 |
+
"id": "nDcRZmb80msE",
|
694 |
+
"trusted": true
|
695 |
+
},
|
696 |
+
"outputs": [],
|
697 |
+
"source": [
|
698 |
+
"# Define sweep config\n",
|
699 |
+
"sweep_configuration = {\n",
|
700 |
+
" \"method\": \"bayes\",\n",
|
701 |
+
" \"name\": \"sweep\",\n",
|
702 |
+
" \"metric\": {\"goal\": \"maximize\", \"name\": \"val_acc\"},\n",
|
703 |
+
" \"parameters\": {\n",
|
704 |
+
" \"batch_size\": {\"values\": [64, 128, 256]},\n",
|
705 |
+
" \"epochs\": {\"values\": [20, 40, 50, 100]},\n",
|
706 |
+
" \"lr\": {\"max\": 0.1, \"min\": 0.0001},\n",
|
707 |
+
" \"n_embd\": {\"values\": [16, 32, 64]},\n",
|
708 |
+
" \"n_head\": {\"values\": [2, 4, 8]},\n",
|
709 |
+
" \"n_layers\": {\"values\": [4, 6, 8]},\n",
|
710 |
+
" \"dropout\": {\"values\": [0, .1, .2, .3]}\n",
|
711 |
+
" },\n",
|
712 |
+
"}\n",
|
713 |
+
"\n",
|
714 |
+
"sweep_id = wandb.sweep(sweep=sweep_configuration, project=\"Tranliteration-Tranformers\")"
|
715 |
+
]
|
716 |
+
},
|
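A sweep is executed by a W&B agent that repeatedly samples a config from sweep_configuration and calls the training function; the "run sweep" cell further down presumably launches it along these lines (a sketch: count is an arbitrary run budget, and train is the function defined in the next cell):

wandb.agent(sweep_id, function=train, count=20)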
717 |
+
{
|
718 |
+
"cell_type": "code",
|
719 |
+
"execution_count": null,
|
720 |
+
"metadata": {
|
721 |
+
"execution": {
|
722 |
+
"iopub.execute_input": "2024-04-04T14:54:15.325199Z",
|
723 |
+
"iopub.status.busy": "2024-04-04T14:54:15.324615Z",
|
724 |
+
"iopub.status.idle": "2024-04-04T14:54:15.330172Z",
|
725 |
+
"shell.execute_reply": "2024-04-04T14:54:15.329301Z",
|
726 |
+
"shell.execute_reply.started": "2024-04-04T14:54:15.325168Z"
|
727 |
+
},
|
728 |
+
"id": "9CguGUG5_1NL",
|
729 |
+
"trusted": true
|
730 |
+
},
|
731 |
+
"outputs": [],
|
732 |
+
"source": [
|
733 |
+
"# wandb.sweep_cancel(sweep_id)\n",
|
734 |
+
"# wandb.finish()\n",
|
735 |
+
"# wandb.run.cancel()"
|
736 |
+
]
|
737 |
+
},
|
738 |
+
{
|
739 |
+
"cell_type": "markdown",
|
740 |
+
"metadata": {
|
741 |
+
"id": "d5T58TQRECbZ"
|
742 |
+
},
|
743 |
+
"source": [
|
744 |
+
"## train function"
|
745 |
+
]
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"cell_type": "code",
|
749 |
+
"execution_count": null,
|
750 |
+
"metadata": {
|
751 |
+
"execution": {
|
752 |
+
"iopub.execute_input": "2024-04-04T14:54:15.331837Z",
|
753 |
+
"iopub.status.busy": "2024-04-04T14:54:15.331538Z",
|
754 |
+
"iopub.status.idle": "2024-04-04T14:54:15.351924Z",
|
755 |
+
"shell.execute_reply": "2024-04-04T14:54:15.351027Z",
|
756 |
+
"shell.execute_reply.started": "2024-04-04T14:54:15.331810Z"
|
757 |
+
},
|
758 |
+
"id": "3GWnCggNFLs3",
|
759 |
+
"trusted": true
|
760 |
+
},
|
761 |
+
"outputs": [],
|
762 |
+
"source": [
|
763 |
+
"def train():\n",
|
764 |
+
" run = wandb.init()\n",
|
765 |
+
"\n",
|
766 |
+
" n_embd = wandb.config.n_embd\n",
|
767 |
+
" n_head = wandb.config.n_head\n",
|
768 |
+
" n_layers = wandb.config.n_layers\n",
|
769 |
+
" dropout = wandb.config.dropout\n",
|
770 |
+
" epochs = wandb.config.epochs\n",
|
771 |
+
" batch_size = wandb.config.batch_size\n",
|
772 |
+
" learning_rate = wandb.config.lr\n",
|
773 |
+
"\n",
|
774 |
+
"\n",
|
775 |
+
" encoder = Encoder(n_embd, n_head, n_layers, dropout)\n",
|
776 |
+
" decoder = Decoder(n_embd, n_head, n_layers, dropout)\n",
|
777 |
+
" encoder.to(device)\n",
|
778 |
+
" decoder.to(device)\n",
|
779 |
+
"\n",
|
780 |
+
" train_losses, train_accuracies, val_losses, val_accuracies = [], [], [], []\n",
|
781 |
+
"\n",
|
782 |
+
" # print the number of parameters in the model\n",
|
783 |
+
" print(sum([p.numel() for p in encoder.parameters()] + [p.numel() for p in decoder.parameters()])/1e3, 'K model parameters')\n",
|
784 |
+
"\n",
|
785 |
+
" train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
|
786 |
+
" val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
|
787 |
+
"\n",
|
788 |
+
" # create a PyTorch optimizer\n",
|
789 |
+
" encoder_optimizer = torch.optim.AdamW(encoder.parameters(), lr=learning_rate)\n",
|
790 |
+
" decoder_optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)\n",
|
791 |
+
"\n",
|
792 |
+
"# print('Step | Training Loss | Validation Loss | Training Accuracy % | Validation Accuracy %')\n",
|
793 |
+
"\n",
|
794 |
+
" least_error = float('inf')\n",
|
795 |
+
" patience = 20 # The number of epochs without improvement to wait before stopping\n",
|
796 |
+
" no_improvement = 0\n",
|
797 |
+
"\n",
|
798 |
+
" for i in range(epochs):\n",
|
799 |
+
" running_loss = 0.0\n",
|
800 |
+
" train_correct = 0\n",
|
801 |
+
"\n",
|
802 |
+
" encoder.train()\n",
|
803 |
+
" decoder.train()\n",
|
804 |
+
"\n",
|
805 |
+
" for j,(train_x,train_y) in enumerate(train_loader):\n",
|
806 |
+
" train_x = train_x.to(device)\n",
|
807 |
+
" train_y = train_y.to(device)\n",
|
808 |
+
"\n",
|
809 |
+
" encoder_optimizer.zero_grad(set_to_none=True)\n",
|
810 |
+
" decoder_optimizer.zero_grad(set_to_none=True)\n",
|
811 |
+
"\n",
|
812 |
+
" encoder_output = encoder(train_x)\n",
|
813 |
+
" logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n",
|
814 |
+
"\n",
|
815 |
+
" encoder_optimizer.zero_grad(set_to_none=True)\n",
|
816 |
+
" decoder_optimizer.zero_grad(set_to_none=True)\n",
|
817 |
+
" loss.backward()\n",
|
818 |
+
" encoder_optimizer.step()\n",
|
819 |
+
" decoder_optimizer.step()\n",
|
820 |
+
"\n",
|
821 |
+
" running_loss += loss\n",
|
822 |
+
" pred_decoder_output = torch.argmax(logits, dim=-1)\n",
|
823 |
+
" # print(pred_decoder_output, \" target: \", train_y[:, 1:])\n",
|
824 |
+
" train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n",
|
825 |
+
"\n",
|
826 |
+
"\n",
|
827 |
+
" ## validation code\n",
|
828 |
+
" running_loss_val, val_correct = 0, 0\n",
|
829 |
+
" encoder.eval()\n",
|
830 |
+
" decoder.eval()\n",
|
831 |
+
" for j,(val_x,val_y) in enumerate(val_loader):\n",
|
832 |
+
" val_x = val_x.to(device)\n",
|
833 |
+
" val_y = val_y.to(device)\n",
|
834 |
+
"\n",
|
835 |
+
" encoder_output = encoder(val_x)\n",
|
836 |
+
" logits, loss = decoder(val_y[:, :-1], encoder_output, val_y[:, 1:])\n",
|
837 |
+
"\n",
|
838 |
+
" running_loss_val += loss\n",
|
839 |
+
" pred_decoder_output = torch.argmax(logits, dim=-1)\n",
|
840 |
+
" val_correct += torch.sum(pred_decoder_output == val_y[:, 1:])\n",
|
841 |
+
"\n",
|
842 |
+
"\n",
|
843 |
+
" if running_loss_val < least_error:\n",
|
844 |
+
" least_error = running_loss_val\n",
|
845 |
+
" no_improvement = 0\n",
|
846 |
+
" else:\n",
|
847 |
+
" no_improvement += 1\n",
|
848 |
+
"\n",
|
849 |
+
" if no_improvement >= patience:\n",
|
850 |
+
" print(f\"Early stopping at epoch {i}\")\n",
|
851 |
+
" break\n",
|
852 |
+
"\n",
|
853 |
+
" wandb.log(\n",
|
854 |
+
" {\n",
|
855 |
+
" \"train_loss\": running_loss / len(train_data),\n",
|
856 |
+
" \"val_loss\": (running_loss_val/len(val_data)),\n",
|
857 |
+
" \"train_acc\": ((train_correct*100) / (len(train_data)* (decoder_block_size-1))),\n",
|
858 |
+
" \"val_acc\": ((val_correct*100)/(len(val_data)* (decoder_block_size-1))),\n",
|
859 |
+
" }\n",
|
860 |
+
" )"
|
861 |
+
]
|
862 |
+
},
|
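Note that the train_acc and val_acc logged above are per-character accuracies over every decoder position, '#' padding included, so they read higher than an exact-match word accuracy would. A self-contained sketch of the difference (toy index tensors, not real model output):

import torch
pred = torch.tensor([[2, 5, 7, 0], [2, 5, 9, 0]])
tgt  = torch.tensor([[2, 5, 7, 0], [2, 5, 8, 0]])
char_acc = (pred == tgt).float().mean()             # 7 of 8 positions match: 0.875
word_acc = (pred == tgt).all(dim=1).float().mean()  # 1 of 2 words fully match: 0.5
print(char_acc.item(), word_acc.item())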
863 |
+
{
|
864 |
+
"cell_type": "markdown",
|
865 |
+
"metadata": {
|
866 |
+
"id": "CxzRR9cjEGDm"
|
867 |
+
},
|
868 |
+
"source": [
|
869 |
+
"## run sweep"
|
870 |
+
]
|
871 |
+
},
|
872 |
+
{
|
873 |
+
"cell_type": "code",
|
874 |
+
"execution_count": null,
|
875 |
+
"metadata": {
|
876 |
+
"colab": {
|
877 |
+
"base_uri": "https://localhost:8080/",
|
878 |
+
"height": 295,
|
879 |
+
"referenced_widgets": [
|
880 |
+
""
|
881 |
+
]
|
882 |
+
},
|
883 |
+
"execution": {
|
884 |
+
"iopub.execute_input": "2024-04-04T14:54:15.353688Z",
|
885 |
+
"iopub.status.busy": "2024-04-04T14:54:15.353125Z"
|
886 |
+
},
|
887 |
+
"id": "u_QFbYe32t7r",
|
888 |
+
"outputId": "97153eab-b36f-454b-9fed-53ae0287aee1",
|
889 |
+
"trusted": true
|
890 |
+
},
|
891 |
+
"outputs": [
|
892 |
+
{
|
893 |
+
"name": "stderr",
|
894 |
+
"output_type": "stream",
|
895 |
+
"text": [
|
896 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: dcco6zur with config:\n",
|
897 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 64\n",
|
898 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0\n",
|
899 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 50\n",
|
900 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.0003\n",
|
901 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 64\n",
|
902 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n",
|
903 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 6\n",
|
904 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcs22m062\u001b[0m (\u001b[33miitmadras\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
905 |
+
]
|
906 |
+
},
|
907 |
+
{
|
908 |
+
"data": {
|
909 |
+
"text/html": [
|
910 |
+
"wandb version 0.16.6 is available! To upgrade, please run:\n",
|
911 |
+
" $ pip install wandb --upgrade"
|
912 |
+
],
|
913 |
+
"text/plain": [
|
914 |
+
"<IPython.core.display.HTML object>"
|
915 |
+
]
|
916 |
+
},
|
917 |
+
"metadata": {},
|
918 |
+
"output_type": "display_data"
|
919 |
+
},
|
920 |
+
{
|
921 |
+
"data": {
|
922 |
+
"text/html": [
|
923 |
+
"Tracking run with wandb version 0.16.4"
|
924 |
+
],
|
925 |
+
"text/plain": [
|
926 |
+
"<IPython.core.display.HTML object>"
|
927 |
+
]
|
928 |
+
},
|
929 |
+
"metadata": {},
|
930 |
+
"output_type": "display_data"
|
931 |
+
},
|
932 |
+
{
|
933 |
+
"data": {
|
934 |
+
"text/html": [
|
935 |
+
"Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_145417-dcco6zur</code>"
|
936 |
+
],
|
937 |
+
"text/plain": [
|
938 |
+
"<IPython.core.display.HTML object>"
|
939 |
+
]
|
940 |
+
},
|
941 |
+
"metadata": {},
|
942 |
+
"output_type": "display_data"
|
943 |
+
},
|
944 |
+
{
|
945 |
+
"data": {
|
946 |
+
"text/html": [
|
947 |
+
"Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur' target=\"_blank\">eager-sweep-2</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
|
948 |
+
],
|
949 |
+
"text/plain": [
|
950 |
+
"<IPython.core.display.HTML object>"
|
951 |
+
]
|
952 |
+
},
|
953 |
+
"metadata": {},
|
954 |
+
"output_type": "display_data"
|
955 |
+
},
|
956 |
+
{
|
957 |
+
"data": {
|
958 |
+
"text/html": [
|
959 |
+
" View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
|
960 |
+
],
|
961 |
+
"text/plain": [
|
962 |
+
"<IPython.core.display.HTML object>"
|
963 |
+
]
|
964 |
+
},
|
965 |
+
"metadata": {},
|
966 |
+
"output_type": "display_data"
|
967 |
+
},
|
968 |
+
{
|
969 |
+
"data": {
|
970 |
+
"text/html": [
|
971 |
+
" View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
|
972 |
+
],
|
973 |
+
"text/plain": [
|
974 |
+
"<IPython.core.display.HTML object>"
|
975 |
+
]
|
976 |
+
},
|
977 |
+
"metadata": {},
|
978 |
+
"output_type": "display_data"
|
979 |
+
},
|
980 |
+
{
|
981 |
+
"data": {
|
982 |
+
"text/html": [
|
983 |
+
" View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur</a>"
|
984 |
+
],
|
985 |
+
"text/plain": [
|
986 |
+
"<IPython.core.display.HTML object>"
|
987 |
+
]
|
988 |
+
},
|
989 |
+
"metadata": {},
|
990 |
+
"output_type": "display_data"
|
991 |
+
},
|
992 |
+
{
|
993 |
+
"name": "stdout",
|
994 |
+
"output_type": "stream",
|
995 |
+
"text": [
|
996 |
+
"710.915 K model parameters\n",
|
997 |
+
"Early stopping at epoch 32\n"
|
998 |
+
]
|
999 |
+
},
|
1000 |
+
{
|
1001 |
+
"data": {
|
1002 |
+
"application/vnd.jupyter.widget-view+json": {
|
1003 |
+
"model_id": "",
|
1004 |
+
"version_major": 2,
|
1005 |
+
"version_minor": 0
|
1006 |
+
},
|
1007 |
+
"text/plain": [
|
1008 |
+
"VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
|
1009 |
+
]
|
1010 |
+
},
|
1011 |
+
"metadata": {},
|
1012 |
+
"output_type": "display_data"
|
1013 |
+
},
|
1014 |
+
{
|
1015 |
+
"data": {
|
1016 |
+
"text/html": [
|
1017 |
+
"<style>\n",
|
1018 |
+
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
|
1019 |
+
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
|
1020 |
+
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
|
1021 |
+
" </style>\n",
|
1022 |
+
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>▁▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████</td></tr><tr><td>train_loss</td><td>█▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_acc</td><td>▁▅▆▆▇▇▇▇▇▇██████████████████████</td></tr><tr><td>val_loss</td><td>█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>97.8125</td></tr><tr><td>train_loss</td><td>0.00096</td></tr><tr><td>val_acc</td><td>95.29739</td></tr><tr><td>val_loss</td><td>0.00286</td></tr></table><br/></div></div>"
|
1023 |
+
],
|
1024 |
+
"text/plain": [
|
1025 |
+
"<IPython.core.display.HTML object>"
|
1026 |
+
]
|
1027 |
+
},
|
1028 |
+
"metadata": {},
|
1029 |
+
"output_type": "display_data"
|
1030 |
+
},
|
1031 |
+
{
|
1032 |
+
"data": {
|
1033 |
+
"text/html": [
|
1034 |
+
" View run <strong style=\"color:#cdcd00\">eager-sweep-2</strong> at: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
1035 |
+
],
|
1036 |
+
"text/plain": [
|
1037 |
+
"<IPython.core.display.HTML object>"
|
1038 |
+
]
|
1039 |
+
},
|
1040 |
+
"metadata": {},
|
1041 |
+
"output_type": "display_data"
|
1042 |
+
},
|
1043 |
+
{
|
1044 |
+
"data": {
|
1045 |
+
"text/html": [
|
1046 |
+
"Find logs at: <code>./wandb/run-20240404_145417-dcco6zur/logs</code>"
|
1047 |
+
],
|
1048 |
+
"text/plain": [
|
1049 |
+
"<IPython.core.display.HTML object>"
|
1050 |
+
]
|
1051 |
+
},
|
1052 |
+
"metadata": {},
|
1053 |
+
"output_type": "display_data"
|
1054 |
+
},
|
1055 |
+
{
|
1056 |
+
"name": "stderr",
|
1057 |
+
"output_type": "stream",
|
1058 |
+
"text": [
|
1059 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: 4qb2bmi8 with config:\n",
|
1060 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 128\n",
|
1061 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0.1\n",
|
1062 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 20\n",
|
1063 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.03\n",
|
1064 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 16\n",
|
1065 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n",
|
1066 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 6\n"
|
1067 |
+
]
|
1068 |
+
},
|
1069 |
+
{
|
1070 |
+
"data": {
|
1071 |
+
"text/html": [
|
1072 |
+
"wandb version 0.16.6 is available! To upgrade, please run:\n",
|
1073 |
+
" $ pip install wandb --upgrade"
|
1074 |
+
],
|
1075 |
+
"text/plain": [
|
1076 |
+
"<IPython.core.display.HTML object>"
|
1077 |
+
]
|
1078 |
+
},
|
1079 |
+
"metadata": {},
|
1080 |
+
"output_type": "display_data"
|
1081 |
+
},
|
1082 |
+
{
|
1083 |
+
"data": {
|
1084 |
+
"text/html": [
|
1085 |
+
"Tracking run with wandb version 0.16.4"
|
1086 |
+
],
|
1087 |
+
"text/plain": [
|
1088 |
+
"<IPython.core.display.HTML object>"
|
1089 |
+
]
|
1090 |
+
},
|
1091 |
+
"metadata": {},
|
1092 |
+
"output_type": "display_data"
|
1093 |
+
},
|
1094 |
+
{
|
1095 |
+
"data": {
|
1096 |
+
"text/html": [
|
1097 |
+
"Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_153243-4qb2bmi8</code>"
|
1098 |
+
],
|
1099 |
+
"text/plain": [
|
1100 |
+
"<IPython.core.display.HTML object>"
|
1101 |
+
]
|
1102 |
+
},
|
1103 |
+
"metadata": {},
|
1104 |
+
"output_type": "display_data"
|
1105 |
+
},
|
1106 |
+
{
|
1107 |
+
"data": {
|
1108 |
+
"text/html": [
|
1109 |
+
"Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8' target=\"_blank\">peach-sweep-3</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
|
1110 |
+
],
|
1111 |
+
"text/plain": [
|
1112 |
+
"<IPython.core.display.HTML object>"
|
1113 |
+
]
|
1114 |
+
},
|
1115 |
+
"metadata": {},
|
1116 |
+
"output_type": "display_data"
|
1117 |
+
},
|
1118 |
+
{
|
1119 |
+
"data": {
|
1120 |
+
"text/html": [
|
1121 |
+
" View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
|
1122 |
+
],
|
1123 |
+
"text/plain": [
|
1124 |
+
"<IPython.core.display.HTML object>"
|
1125 |
+
]
|
1126 |
+
},
|
1127 |
+
"metadata": {},
|
1128 |
+
"output_type": "display_data"
|
1129 |
+
},
|
1130 |
+
{
|
1131 |
+
"data": {
|
1132 |
+
"text/html": [
|
1133 |
+
" View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
|
1134 |
+
],
|
1135 |
+
"text/plain": [
|
1136 |
+
"<IPython.core.display.HTML object>"
|
1137 |
+
]
|
1138 |
+
},
|
1139 |
+
"metadata": {},
|
1140 |
+
"output_type": "display_data"
|
1141 |
+
},
|
1142 |
+
{
|
1143 |
+
"data": {
|
1144 |
+
"text/html": [
|
1145 |
+
" View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8</a>"
|
1146 |
+
],
|
1147 |
+
"text/plain": [
|
1148 |
+
"<IPython.core.display.HTML object>"
|
1149 |
+
]
|
1150 |
+
},
|
1151 |
+
"metadata": {},
|
1152 |
+
"output_type": "display_data"
|
1153 |
+
},
|
1154 |
+
{
|
1155 |
+
"name": "stdout",
|
1156 |
+
"output_type": "stream",
|
1157 |
+
"text": [
|
1158 |
+
"48.755 K model parameters\n"
|
1159 |
+
]
|
1160 |
+
},
|
1161 |
+
{
|
1162 |
+
"data": {
|
1163 |
+
"application/vnd.jupyter.widget-view+json": {
|
1164 |
+
"model_id": "",
|
1165 |
+
"version_major": 2,
|
1166 |
+
"version_minor": 0
|
1167 |
+
},
|
1168 |
+
"text/plain": [
|
1169 |
+
"VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
|
1170 |
+
]
|
1171 |
+
},
|
1172 |
+
"metadata": {},
|
1173 |
+
"output_type": "display_data"
|
1174 |
+
},
|
1175 |
+
{
|
1176 |
+
"data": {
|
1177 |
+
"text/html": [
|
1178 |
+
"<style>\n",
|
1179 |
+
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
|
1180 |
+
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
|
1181 |
+
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
|
1182 |
+
" </style>\n",
|
1183 |
+
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>▁▅▇▇▇███████████████</td></tr><tr><td>train_loss</td><td>█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_acc</td><td>▁▆▇▇▇███████████████</td></tr><tr><td>val_loss</td><td>█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>89.37686</td></tr><tr><td>train_loss</td><td>0.00256</td></tr><tr><td>val_acc</td><td>92.66765</td></tr><tr><td>val_loss</td><td>0.0018</td></tr></table><br/></div></div>"
|
1184 |
+
],
|
1185 |
+
"text/plain": [
|
1186 |
+
"<IPython.core.display.HTML object>"
|
1187 |
+
]
|
1188 |
+
},
|
1189 |
+
"metadata": {},
|
1190 |
+
"output_type": "display_data"
|
1191 |
+
},
|
1192 |
+
{
|
1193 |
+
"data": {
|
1194 |
+
"text/html": [
|
1195 |
+
" View run <strong style=\"color:#cdcd00\">peach-sweep-3</strong> at: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
1196 |
+
],
|
1197 |
+
"text/plain": [
|
1198 |
+
"<IPython.core.display.HTML object>"
|
1199 |
+
]
|
1200 |
+
},
|
1201 |
+
"metadata": {},
|
1202 |
+
"output_type": "display_data"
|
1203 |
+
},
|
1204 |
+
{
|
1205 |
+
"data": {
|
1206 |
+
"text/html": [
|
1207 |
+
"Find logs at: <code>./wandb/run-20240404_153243-4qb2bmi8/logs</code>"
|
1208 |
+
],
|
1209 |
+
"text/plain": [
|
1210 |
+
"<IPython.core.display.HTML object>"
|
1211 |
+
]
|
1212 |
+
},
|
1213 |
+
"metadata": {},
|
1214 |
+
"output_type": "display_data"
|
1215 |
+
},
|
1216 |
+
{
|
1217 |
+
"name": "stderr",
|
1218 |
+
"output_type": "stream",
|
1219 |
+
"text": [
|
1220 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: gtz48xe5 with config:\n",
|
1221 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 32\n",
|
1222 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0\n",
|
1223 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 30\n",
|
1224 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.01\n",
|
1225 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 16\n",
|
1226 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n",
|
1227 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 4\n"
|
1228 |
+
]
|
1229 |
+
},
|
1230 |
+
{
|
1231 |
+
"data": {
|
1232 |
+
"text/html": [
|
1233 |
+
"wandb version 0.16.6 is available! To upgrade, please run:\n",
|
1234 |
+
" $ pip install wandb --upgrade"
|
1235 |
+
],
|
1236 |
+
"text/plain": [
|
1237 |
+
"<IPython.core.display.HTML object>"
|
1238 |
+
]
|
1239 |
+
},
|
1240 |
+
"metadata": {},
|
1241 |
+
"output_type": "display_data"
|
1242 |
+
},
|
1243 |
+
{
|
1244 |
+
"data": {
|
1245 |
+
"text/html": [
|
1246 |
+
"Tracking run with wandb version 0.16.4"
|
1247 |
+
],
|
1248 |
+
"text/plain": [
|
1249 |
+
"<IPython.core.display.HTML object>"
|
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_154533-gtz48xe5</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5' target=\"_blank\">cerulean-sweep-4</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "33.683 K model parameters\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(Label(value='0.001 MB of 0.047 MB uploaded\\r'), FloatProgress(value=0.028017589156043247, max=1…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<style>\n",
+       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
+       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
+       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
+       "    </style>\n",
+       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>▁▆▆▇▇▇▇▇▇█████████████████████</td></tr><tr><td>train_loss</td><td>█▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_acc</td><td>▁▃▄▅▅▆▇▆▇▆▇▇▆▆▆▇▆▇█▇▇▇▇██▇▇▇█▇</td></tr><tr><td>val_loss</td><td>█▆▅▃▃▃▂▂▂▂▂▂▃▂▂▂▃▂▂▂▂▂▂▁▁▂▂▂▁▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_acc</td><td>92.21615</td></tr><tr><td>train_loss</td><td>0.00725</td></tr><tr><td>val_acc</td><td>93.30009</td></tr><tr><td>val_loss</td><td>0.00663</td></tr></table><br/></div></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run <strong style=\"color:#cdcd00\">cerulean-sweep-4</strong> at: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Find logs at: <code>./wandb/run-20240404_154533-gtz48xe5/logs</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: aoy7fr9k with config:\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 256\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0.1\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 30\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.0003\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 64\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 8\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 4\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "wandb version 0.16.6 is available!  To upgrade, please run:\n",
+       " $ pip install wandb --upgrade"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Tracking run with wandb version 0.16.4"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Run data is saved locally in <code>/kaggle/working/wandb/run-20240404_163029-aoy7fr9k</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Syncing run <strong><a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k' target=\"_blank\">warm-sweep-6</a></strong> to <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>Sweep page: <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View project at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View sweep at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run at <a href='https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k' target=\"_blank\">https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "478.595 K model parameters\n"
+     ]
+    }
+   ],
+   "source": [
+    "wandb.agent(sweep_id=sweep_id, function=train)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "cNtTaEc6kxuC"
+   },
+   "source": [
+    "# Test Time\n",
+    "Since this is the best model (by validation accuracy), we retrain it on the combined train and validation data.\n",
+    "We then evaluate the model on the test data."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "QcgfjfD9lvWJ"
+   },
+   "source": [
+    "## Best hyperparameters from validation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2024-04-06T15:11:46.239015Z",
+     "iopub.status.busy": "2024-04-06T15:11:46.237962Z",
+     "iopub.status.idle": "2024-04-06T15:11:46.337285Z",
+     "shell.execute_reply": "2024-04-06T15:11:46.336384Z",
+     "shell.execute_reply.started": "2024-04-06T15:11:46.238979Z"
+    },
+    "id": "q7SXqJhekxuC",
+    "outputId": "17c0dfd2-2e0b-4449-80fe-9f7a2ce68c28",
+    "trusted": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " \n"
+     ]
+    }
+   ],
+   "source": [
+    "n_embd = 128\n",
+    "batch_size = 64\n",
+    "learning_rate = 3e-3\n",
+    "n_head = 8  # other options: factors of n_embd, like 2, 8\n",
+    "n_layers = 6\n",
+    "dropout = 0.1\n",
+    "epochs = 200\n",
+    "\n",
+    "encoder = Encoder(n_embd, n_head, n_layers, dropout)\n",
+    "decoder = Decoder(n_embd, n_head, n_layers, dropout)\n",
+    "encoder.to(device)\n",
+    "decoder.to(device)\n",
+    "print(\" \")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "P0-9k1L6l0iZ"
+   },
+   "source": [
+    "## Train on train_data + val_data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2024-04-06T15:11:51.054081Z",
+     "iopub.status.busy": "2024-04-06T15:11:51.053142Z",
+     "iopub.status.idle": "2024-04-06T17:55:02.351999Z",
+     "shell.execute_reply": "2024-04-06T17:55:02.350323Z",
+     "shell.execute_reply.started": "2024-04-06T15:11:51.054049Z"
+    },
+    "id": "TQVFJyvlTMjS",
+    "trusted": true
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "# print the number of parameters in the model\n",
+    "print(sum([p.numel() for p in encoder.parameters()] + [p.numel() for p in decoder.parameters()])/1e3, 'K model parameters')\n",
+    "\n",
+    "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
+    "val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
+    "\n",
+    "# create a PyTorch optimizer\n",
+    "encoder_optimizer = torch.optim.AdamW(encoder.parameters(), lr=learning_rate)\n",
+    "decoder_optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)\n",
+    "\n",
+    "# print('Step | Training Loss | Validation Loss | Training Accuracy % | Validation Accuracy %')\n",
+    "\n",
+    "least_error = float('inf')\n",
+    "patience = 20  # The number of epochs without improvement to wait before stopping\n",
+    "no_improvement = 0\n",
+    "\n",
+    "for i in range(epochs):\n",
+    "    running_loss = 0.0\n",
+    "    train_correct = 0\n",
+    "\n",
+    "    encoder.train()\n",
+    "    decoder.train()\n",
+    "\n",
+    "    for j, (train_x, train_y) in enumerate(train_loader):\n",
+    "        train_x = train_x.to(device)\n",
+    "        train_y = train_y.to(device)\n",
+    "\n",
+    "        encoder_optimizer.zero_grad(set_to_none=True)\n",
+    "        decoder_optimizer.zero_grad(set_to_none=True)\n",
+    "\n",
+    "        encoder_output = encoder(train_x)\n",
+    "        logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n",
+    "\n",
+    "        encoder_optimizer.zero_grad(set_to_none=True)\n",
+    "        decoder_optimizer.zero_grad(set_to_none=True)\n",
+    "        loss.backward()\n",
+    "        encoder_optimizer.step()\n",
+    "        decoder_optimizer.step()\n",
+    "\n",
+    "        running_loss += loss\n",
+    "        pred_decoder_output = torch.argmax(logits, dim=-1)\n",
+    "        # print(pred_decoder_output, \" target: \", train_y[:, 1:])\n",
+    "        train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n",
+    "\n",
+    "    for j, (train_x, train_y) in enumerate(val_loader):\n",
+    "        train_x = train_x.to(device)\n",
+    "        train_y = train_y.to(device)\n",
+    "\n",
+    "        encoder_optimizer.zero_grad(set_to_none=True)\n",
+    "        decoder_optimizer.zero_grad(set_to_none=True)\n",
+    "\n",
+    "        encoder_output = encoder(train_x)\n",
+    "        logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n",
+    "\n",
+    "        encoder_optimizer.zero_grad(set_to_none=True)\n",
+    "        decoder_optimizer.zero_grad(set_to_none=True)\n",
+    "        loss.backward()\n",
+    "        encoder_optimizer.step()\n",
+    "        decoder_optimizer.step()\n",
+    "\n",
+    "        running_loss += loss\n",
+    "        pred_decoder_output = torch.argmax(logits, dim=-1)\n",
+    "        # print(pred_decoder_output, \" target: \", train_y[:, 1:])\n",
+    "        train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n",
+    "\n",
+    "\n",
+    "    metrics = {\n",
+    "        \"train_loss\": running_loss.cpu().detach().numpy() / (len(train_data)+len(val_data)),\n",
+    "        \"train_acc\": ((train_correct*100) / ((len(train_data)+len(val_data))* (decoder_block_size-1))),\n",
+    "    }\n",
+    "    if i % 5 == 0:\n",
+    "        print(\"Step: \", i)\n",
+    "        print(\"train_loss: \", metrics[\"train_loss\"])\n",
+    "        print(\"train_acc: \", metrics[\"train_acc\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2024-04-06T00:22:11.853957Z",
+     "iopub.status.busy": "2024-04-06T00:22:11.852912Z",
+     "iopub.status.idle": "2024-04-06T00:22:11.923978Z",
+     "shell.execute_reply": "2024-04-06T00:22:11.923143Z",
+     "shell.execute_reply.started": "2024-04-06T00:22:11.853919Z"
+    },
+    "id": "hAjg5s0IkxuC",
+    "trusted": true
+   },
+   "outputs": [],
+   "source": [
+    "PATH = '/kaggle/working/encoder.pth'\n",
+    "torch.save(encoder, PATH)\n",
+    "PATH = '/kaggle/working/decoder.pth'\n",
+    "torch.save(decoder, PATH)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "x4M3aMxTl-zb"
+   },
+   "source": [
+    "## Generate output sequence"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2024-04-06T12:29:19.489092Z",
+     "iopub.status.busy": "2024-04-06T12:29:19.488711Z",
+     "iopub.status.idle": "2024-04-06T12:29:19.496406Z",
+     "shell.execute_reply": "2024-04-06T12:29:19.495353Z",
+     "shell.execute_reply.started": "2024-04-06T12:29:19.489065Z"
+    },
+    "id": "mfIxu6njkxuD",
+    "trusted": true
+   },
+   "outputs": [],
+   "source": [
+    "def generate(input):\n",
+    "    B, T = input.shape\n",
+    "    encoder_output = encoder(input)\n",
+    "    idx = torch.full((B, 1), 2, dtype=torch.long, device=device)  # (B, 1), start token\n",
+    "\n",
+    "    # idx is a (B, T) array of indices in the current context\n",
+    "    for _ in range(decoder_block_size-1):\n",
+    "        # get the predictions\n",
+    "        logits, loss = decoder(idx, encoder_output)  # logits (B, T, vocab_size)\n",
+    "        # focus only on the last time step\n",
+    "        logits = logits[:, -1, :]  # becomes (B, C)\n",
+    "        # greedily pick the most likely next token\n",
+    "        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)\n",
+    "        # append the predicted index to the running sequence\n",
+    "        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n",
+    "    return idx"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "BeB2nYeFmXy8"
+   },
+   "source": [
+    "## Check Test Accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2024-04-06T18:00:25.146854Z",
+     "iopub.status.busy": "2024-04-06T18:00:25.146119Z",
+     "iopub.status.idle": "2024-04-06T18:00:25.156303Z",
+     "shell.execute_reply": "2024-04-06T18:00:25.155453Z",
+     "shell.execute_reply.started": "2024-04-06T18:00:25.146826Z"
+    },
+    "id": "dIzXiSLBkxuD",
+    "outputId": "ebe1d201-32bb-4372-e64a-62ebe173799d",
+    "trusted": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "test accuracy(word level) : 67.2188\n"
+     ]
+    }
+   ],
+   "source": [
+    "def check():\n",
+    "    ## validation code\n",
+    "    running_loss_val, val_correct = 0, 0\n",
+    "    encoder.eval()\n",
+    "    decoder.eval()\n",
+    "    test_loader = DataLoader(test_data, batch_size=64, shuffle=True)\n",
+    "    for _ in range(50):\n",
+    "        val_x, val_y = next(iter(test_loader))\n",
+    "\n",
+    "        val_x = val_x.to(device)\n",
+    "        val_y = val_y.to(device)\n",
+    "\n",
+    "        output = generate(val_x)\n",
+    "\n",
+    "        encoder_output = encoder(val_x)\n",
+    "        logits, loss = decoder(val_y[:, :-1], encoder_output, val_y[:, 1:])\n",
+    "\n",
+    "        running_loss_val += loss\n",
+    "        # checking val_correct for the whole sequence\n",
+    "        val_correct += torch.sum(torch.sum(output[:, 1:] != val_y[:, 1:], dim=-1) == 0)\n",
+    "\n",
+    "    print(\"test accuracy(word level) : \", ((val_correct.cpu().detach().numpy()*100) / len(test_data)))\n",
+    "\n",
+    "check()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "LDP4KvWdFnIL"
+   },
+   "source": [
+    "# Plotting the Attention HeatMaps"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "4WfJEdcgFmiI",
+    "trusted": true
+   },
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from matplotlib.font_manager import FontProperties\n",
+    "tel_font = FontProperties(fname = 'TiroDevanagariHindi-Regular.ttf')\n",
+    "# Assuming you have attention_weights of shape (batch_size, output_sequence_length, batch_size, input_sequence_length)\n",
+    "# and prediction_matrix of shape (batch_size, output_sequence_length)\n",
+    "# and input_matrix of shape (batch_size, input_sequence_length)\n",
+    "\n",
+    "# Define the grid dimensions\n",
+    "rows = int(np.ceil(np.sqrt(12)))\n",
+    "cols = int(np.ceil(12 / rows))\n",
+    "\n",
+    "# Create a figure and subplots\n",
+    "fig, axes = plt.subplots(rows, cols, figsize=(9, 9))\n",
+    "\n",
+    "for i, ax in enumerate(axes.flatten()):\n",
+    "    if i < 12:\n",
+    "        prediction = [opLang.index2char[j.item()] for j in pred[i+1]]\n",
+    "\n",
+    "        pred_word = \"\"\n",
+    "        input_word = \"\"\n",
+    "\n",
+    "        for j in range(len(prediction)):\n",
+    "            # Ignore padding\n",
+    "            if(prediction[j] != '#'):\n",
+    "                pred_word += prediction[j]\n",
+    "            else:\n",
+    "                break\n",
+    "        input_seq = [ipLang.index2char[j.item()] for j in testData[i][0]]\n",
+    "\n",
+    "        for j in range(len(input_seq)):\n",
+    "            if(input_seq[j] != '#'):\n",
+    "                input_word += input_seq[j]\n",
+    "            else:\n",
+    "                break\n",
+    "        attn_weights = atten_weights[i, :len(pred_word), :len(input_word)].detach().cpu().numpy()\n",
+    "        ax.imshow(attn_weights.T, cmap='hot', interpolation='nearest')\n",
+    "        ax.xaxis.set_label_position('top')\n",
+    "        ax.set_title(f'Example {i+1}')\n",
+    "        ax.set_xlabel('Output predicted')\n",
+    "        ax.set_ylabel('Input word')\n",
+    "        ax.set_xticks(np.arange(len(pred_word)))\n",
+    "        ax.set_xticklabels(pred_word, rotation = 90, fontproperties = tel_font, fontdict={'fontsize':8})\n",
+    "        ax.xaxis.tick_top()\n",
+    "\n",
+    "        ax.set_yticks(np.arange(len(input_word)))\n",
+    "        ax.set_yticklabels(input_word, rotation=90)\n",
+    "\n",
+    "\n",
+    "\n",
+    "# Adjust the spacing between subplots\n",
+    "plt.tight_layout()\n",
+    "\n",
+    "# Show the plot\n",
+    "plt.show()\n",
+    "wandb.init(project='CS6910_Assignment_3')\n",
+    "\n",
+    "# Convert the matplotlib figure to an image\n",
+    "fig.canvas.draw()\n",
+    "image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')\n",
+    "image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
+    "\n",
+    "# Log the image in wandb\n",
+    "wandb.log({\"attention_heatmaps\": [wandb.Image(image)]})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "FnHR_oql6-S4"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [
+    "hRdpoWePeYHn",
+    "44xIRolL_T_d",
+    "XdltQ7oJCq1j",
+    "GgPU486JC8Mz",
+    "658W9RARGEUf",
+    "q7fAgs5uQni_",
+    "n4rGh7vuQqaa",
+    "nvyRJWUUbR2f",
+    "8ETW0BG_Pa24",
+    "MQPGy32rnD3V",
+    "z_aYZvDD1OHU",
+    "pKvBd5mKf0Hf",
+    "FYMa5jTQRUaB",
+    "zfuv5FoA1wt2",
+    "W7CYNChRGuGK"
+   ],
+   "gpuType": "T4",
+   "include_colab_link": true,
+   "provenance": [],
+   "toc_visible": true
+  },
+  "kaggle": {
+   "accelerator": "gpu",
+   "dataSources": [
+    {
+     "datasetId": 4721249,
+     "sourceId": 8013732,
+     "sourceType": "datasetVersion"
+    }
+   ],
+   "dockerImageVersionId": 30674,
+   "isGpuEnabled": true,
+   "isInternetEnabled": true,
+   "language": "python",
+   "sourceType": "notebook"
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
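
A note on the word-level accuracy used in check() above: a prediction counts as correct only when every generated token matches the target, not per character. Below is a minimal standalone sketch of that counting logic with made-up token tensors (the values are purely illustrative, not from the dataset):

import torch

# Two "words" of four tokens each; the first matches the target exactly,
# the second differs in one token.
output = torch.tensor([[2, 5, 7, 1],
                       [2, 5, 8, 1]])
target = torch.tensor([[2, 5, 7, 1],
                       [2, 5, 9, 1]])

# Count mismatches per word, skipping the start token at position 0;
# a row with zero mismatches is an exact (word-level) match.
mismatches = torch.sum(output[:, 1:] != target[:, 1:], dim=-1)
word_correct = torch.sum(mismatches == 0)
print(word_correct.item())  # 1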
predictions_attention/predictions.csv
ADDED
The diff for this file is too large to render. See raw diff.

predictions_transformer/predictions.csv
ADDED
The diff for this file is too large to render. See raw diff.

predictions_vanilla/predictions _vanilla.csv
ADDED
The diff for this file is too large to render. See raw diff.
requirements.txt
ADDED
@@ -0,0 +1,93 @@
+aiofiles==23.2.1
+annotated-types==0.7.0
+anyio==4.6.0
+asttokens==2.4.1
+certifi==2024.8.30
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+comm==0.2.2
+contourpy==1.3.0
+cycler==0.12.1
+debugpy==1.8.6
+decorator==5.1.1
+docker-pycreds==0.4.0
+executing==2.1.0
+fastapi==0.115.0
+ffmpy==0.4.0
+filelock==3.16.1
+fonttools==4.54.1
+fsspec==2024.9.0
+gitdb==4.0.11
+GitPython==3.1.43
+gradio==4.44.1
+gradio_client==1.3.0
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+huggingface-hub==0.25.1
+idna==3.10
+importlib_resources==6.4.5
+ipykernel==6.29.5
+ipython==8.28.0
+jedi==0.19.1
+Jinja2==3.1.4
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+kiwisolver==1.4.7
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.2
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mpmath==1.3.0
+nest-asyncio==1.6.0
+networkx==3.3
+numpy==2.1.1
+orjson==3.10.7
+packaging==24.1
+pandas==2.2.3
+parso==0.8.4
+pillow==10.4.0
+platformdirs==4.3.6
+prompt_toolkit==3.0.48
+protobuf==5.28.2
+psutil==6.0.0
+pure_eval==0.2.3
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydub==0.25.1
+Pygments==2.18.0
+pyparsing==3.1.4
+python-dateutil==2.9.0.post0
+python-multipart==0.0.12
+pytz==2024.2
+PyYAML==6.0.2
+pyzmq==26.2.0
+requests==2.32.3
+rich==13.9.1
+ruff==0.6.8
+semantic-version==2.10.0
+sentry-sdk==2.15.0
+setproctitle==1.3.3
+setuptools==75.1.0
+shellingham==1.5.4
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+stack-data==0.6.3
+starlette==0.38.6
+sympy==1.13.3
+tomlkit==0.12.0
+torch==2.4.1
+tornado==6.4.1
+tqdm==4.66.5
+traitlets==5.14.3
+typer==0.12.5
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+uvicorn==0.31.0
+wandb==0.18.3
+wcwidth==0.2.13
+websockets==12.0
src/decoder.py
ADDED
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from src.helper import get_cell
+
+class Decoder(nn.Module):
+    def __init__(self,
+                 out_sz: int,
+                 embed_sz: int,
+                 hidden_sz: int,
+                 cell_type: str,
+                 n_layers: int,
+                 dropout: float,
+                 device: str):
+
+        super(Decoder, self).__init__()
+        self.hidden_sz = hidden_sz
+        self.n_layers = n_layers
+        self.dropout = dropout
+        self.cell_type = cell_type
+        self.embedding = nn.Embedding(out_sz, embed_sz)
+        self.device = device
+
+        self.rnn = get_cell(cell_type)(input_size = embed_sz,
+                                       hidden_size = hidden_sz,
+                                       num_layers = n_layers,
+                                       dropout = dropout)
+
+        self.out = nn.Linear(hidden_sz, out_sz)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, input, hidden, cell):
+        output = self.embedding(input).view(1, 1, -1)
+        output = F.relu(output)
+
+        if(self.cell_type == "LSTM"):
+            output, (hidden, cell) = self.rnn(output, (hidden, cell))
+        else:
+            output, hidden = self.rnn(output, hidden)
+
+        output = self.softmax(self.out(output[0]))
+        return output, hidden, cell
+
+    def initHidden(self):
+        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=self.device)
+
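
The decoder above consumes one character index per call and returns log-probabilities over the output alphabet along with the updated hidden (and, for LSTM, cell) state. A rough sketch of a single decoding step, assuming the module is importable as in this repo; the sizes are illustrative, and out_sz would normally be output_lang.n_chars:

import torch
from src.decoder import Decoder

dec = Decoder(out_sz=70, embed_sz=16, hidden_sz=512, cell_type="LSTM",
              n_layers=2, dropout=0.1, device="cpu")

token = torch.tensor([[0]])   # SOS_token, shape (1, 1)
hidden = dec.initHidden()     # (n_layers, 1, hidden_sz)
cell = dec.initHidden()       # LSTM cell state, same shape

log_probs, hidden, cell = dec(token, hidden, cell)
next_token = log_probs.topk(1).indices  # greedy choice, as in Translator.evaluate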
src/encoder.py
ADDED
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+from src.helper import get_cell
+
+class Encoder(nn.Module):
+    def __init__(self,
+                 in_sz: int,
+                 embed_sz: int,
+                 hidden_sz: int,
+                 cell_type: str,
+                 n_layers: int,
+                 dropout: float,
+                 device: str):
+
+        super(Encoder, self).__init__()
+        self.hidden_sz = hidden_sz
+        self.n_layers = n_layers
+        self.dropout = dropout
+        self.cell_type = cell_type
+        self.embedding = nn.Embedding(in_sz, embed_sz)
+        self.device = device
+
+        self.rnn = get_cell(cell_type)(input_size = embed_sz,
+                                       hidden_size = hidden_sz,
+                                       num_layers = n_layers,
+                                       dropout = dropout)
+
+    def forward(self, input, hidden, cell):
+        embedded = self.embedding(input).view(1, 1, -1)
+
+        if(self.cell_type == "LSTM"):
+            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
+        else:
+            output, hidden = self.rnn(embedded, hidden)
+
+        return output, hidden, cell
+
+    def initHidden(self):
+        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=self.device)
src/helper.py
ADDED
@@ -0,0 +1,60 @@
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from src.language import Language, EOS_token
+
+def get_data(lang: str, type: str) -> list[list[str]]:
+    """
+    Returns: 'pairs': list of [input_word, target_word] pairs
+    """
+    path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
+    df = pd.read_csv(path, header=None)
+    pairs = df.values.tolist()
+    return pairs
+
+def get_languages(lang: str):
+    """
+    Returns
+    1. input_lang: input language - English
+    2. output_lang: output language - Given language
+    3. pairs: list of [input_word, target_word] pairs
+    """
+    input_lang = Language('eng')
+    output_lang = Language(lang)
+    pairs = get_data(lang, "train")
+    for pair in pairs:
+        input_lang.addWord(pair[0])
+        output_lang.addWord(pair[1])
+    return input_lang, output_lang, pairs
+
+def get_cell(cell_type: str):
+    if cell_type == "LSTM":
+        return nn.LSTM
+    elif cell_type == "GRU":
+        return nn.GRU
+    elif cell_type == "RNN":
+        return nn.RNN
+    else:
+        raise Exception("Invalid cell type")
+
+def get_optimizer(optimizer: str):
+    if optimizer == "SGD":
+        return optim.SGD
+    elif optimizer == "ADAM":
+        return optim.Adam
+    else:
+        raise Exception("Invalid optimizer")
+
+def indexesFromWord(lang: Language, word: str):
+    return [lang.word2index[char] for char in word]
+
+def tensorFromWord(lang: Language, word: str, device: str):
+    indexes = indexesFromWord(lang, word)
+    indexes.append(EOS_token)
+    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
+
+def tensorsFromPair(input_lang: Language, output_lang: Language, pair: list[str], device: str):
+    input_tensor = tensorFromWord(input_lang, pair[0], device)
+    target_tensor = tensorFromWord(output_lang, pair[1], device)
+    return (input_tensor, target_tensor)
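
To make the tensor pipeline concrete: tensorFromWord maps each character through the language's vocabulary and appends EOS_token, producing a column vector of indices. A small sketch with a hand-built vocabulary (no dataset files needed; the words are made up):

import torch
from src.language import Language
from src.helper import tensorFromWord

lang = Language("eng")
lang.addWord("aba")           # registers 'a' -> 2, 'b' -> 3

t = tensorFromWord(lang, "ab", "cpu")
print(t.squeeze(1).tolist())  # [2, 3, 1] -- 'a', 'b', then EOS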
src/language.py
ADDED
@@ -0,0 +1,24 @@
+# Language Model
+SOS_token = 0
+EOS_token = 1
+
+class Language:
+    def __init__(self, name):
+        self.name = name
+        self.word2index = {}
+        self.word2count = {}
+        self.index2word = {SOS_token: "<", EOS_token: ">"}
+        self.n_chars = 2  # Count SOS and EOS
+
+    def addWord(self, word):
+        for char in word:
+            self.addChar(char)
+
+    def addChar(self, char):
+        if char not in self.word2index:
+            self.word2index[char] = self.n_chars
+            self.word2count[char] = 1
+            self.index2word[self.n_chars] = char
+            self.n_chars += 1
+        else:
+            self.word2count[char] += 1
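
A quick illustration of how Language accumulates a character vocabulary: counts go to word2count, and indices start at 2 because 0 and 1 are reserved for SOS and EOS. The words below are made up:

from src.language import Language, SOS_token, EOS_token

lang = Language("tam")
for word in ["abc", "cab"]:
    lang.addWord(word)

print(lang.word2index)   # {'a': 2, 'b': 3, 'c': 4}
print(lang.word2count)   # {'a': 2, 'b': 2, 'c': 2}
print(lang.index2word[SOS_token], lang.index2word[EOS_token])  # < >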
src/translator.py
ADDED
@@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from src.helper import get_optimizer, tensorsFromPair, get_languages, tensorFromWord, get_data
+from src.language import SOS_token, EOS_token
+from src.encoder import Encoder
+from src.decoder import Decoder
+import random
+import time
+import numpy as np
+
+PRINT_EVERY = 5000
+PLOT_EVERY = 100
+
+class Translator:
+    def __init__(self, lang: str, params: dict, device: str):
+        self.lang = lang
+        self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
+        self.input_size = self.input_lang.n_chars
+        self.output_size = self.output_lang.n_chars
+        self.device = device
+
+        self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair, self.device) for pair in self.pairs]
+
+        self.encoder = Encoder(in_sz = self.input_size,
+                               embed_sz = params["embed_size"],
+                               hidden_sz = params["hidden_size"],
+                               cell_type = params["cell_type"],
+                               n_layers = params["num_layers"],
+                               dropout = params["dropout"],
+                               device=self.device).to(self.device)
+
+        self.decoder = Decoder(out_sz = self.output_size,
+                               embed_sz = params["embed_size"],
+                               hidden_sz = params["hidden_size"],
+                               cell_type = params["cell_type"],
+                               n_layers = params["num_layers"],
+                               dropout = params["dropout"],
+                               device=self.device).to(self.device)
+
+        self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"])
+        self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"])
+
+        self.criterion = nn.NLLLoss()
+
+        self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
+        self.max_length = params["max_length"]
+
+    def train_single(self, input_tensor, target_tensor):
+        encoder_hidden = self.encoder.initHidden()
+        encoder_cell = self.encoder.initHidden()
+
+        self.encoder_optimizer.zero_grad()
+        self.decoder_optimizer.zero_grad()
+
+        input_length = input_tensor.size(0)
+        target_length = target_tensor.size(0)
+
+        encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=self.device)
+
+        loss = 0
+
+        for ei in range(input_length):
+            encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+            encoder_outputs[ei] = encoder_output[0, 0]
+
+        decoder_input = torch.tensor([[SOS_token]], device=self.device)
+        decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+        use_teacher_forcing = True if random.random() < self.teacher_forcing_ratio else False
+
+        if use_teacher_forcing:
+            for di in range(target_length):
+                decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)
+                loss += self.criterion(decoder_output, target_tensor[di])
+
+                decoder_input = target_tensor[di]
+        else:
+            for di in range(target_length):
+                decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)
+                loss += self.criterion(decoder_output, target_tensor[di])
+
+                topv, topi = decoder_output.topk(1)
+                decoder_input = topi.squeeze().detach()
+                if decoder_input.item() == EOS_token:
+                    break
+
+        loss.backward()
+        self.encoder_optimizer.step()
+        self.decoder_optimizer.step()
+
+        return loss.item() / target_length
+
+    def train(self, iters=-1):
+        start_time = time.time()
+        plot_losses = []
+        print_loss_total = 0
+        plot_loss_total = 0
+
+        random.shuffle(self.training_pairs)
+        iters = len(self.training_pairs) if iters == -1 else iters
+
+        for iter in range(1, iters+1):
+            training_pair = self.training_pairs[iter - 1]
+            input_tensor = training_pair[0]
+            target_tensor = training_pair[1]
+
+            loss = self.train_single(input_tensor, target_tensor)
+            print_loss_total += loss
+            plot_loss_total += loss
+
+            if iter % PRINT_EVERY == 0:
+                print_loss_avg = print_loss_total / PRINT_EVERY
+                print_loss_total = 0
+                current_time = time.time()
+                print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))
+
+            if iter % PLOT_EVERY == 0:
+                plot_loss_avg = plot_loss_total / PLOT_EVERY
+                plot_losses.append(plot_loss_avg)
+                plot_loss_total = 0
+
+        return plot_losses
+
+    def evaluate(self, word):
+        with torch.no_grad():
+            input_tensor = tensorFromWord(self.input_lang, word, self.device)
+            input_length = input_tensor.size()[0]
+            encoder_hidden = self.encoder.initHidden()
+            encoder_cell = self.encoder.initHidden()
+
+            encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=self.device)
+
+            for ei in range(input_length):
+                encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+                encoder_outputs[ei] += encoder_output[0, 0]
+
+            decoder_input = torch.tensor([[SOS_token]], device=self.device)
+            decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+            decoded_chars = ""
+
+            for di in range(self.max_length):
+                decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)
+                topv, topi = decoder_output.topk(1)
+
+                if topi.item() == EOS_token:
+                    break
+                else:
+                    decoded_chars += self.output_lang.index2word[topi.item()]
+
+                decoder_input = topi.squeeze().detach()
+
+            return decoded_chars
+
+    def test_validate(self, type: str):
+        pairs = get_data(self.lang, type)
+        accuracy = np.sum([self.evaluate(pair[0]) == pair[1] for pair in pairs])
+        return accuracy / len(pairs)
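
End-to-end, the class is driven as below. train() walks the shuffled pairs once, and train_single flips a coin per word to decide between teacher forcing (feeding the ground-truth character) and free running (feeding the model's own argmax). A usage sketch, assuming the aksharantar_sampled data is present on disk; the hyperparameter values and the "valid" split name are illustrative:

import torch
from src.translator import Translator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keys mirror what Translator expects; values here are just one plausible choice.
params = {
    "embed_size": 16, "hidden_size": 256, "cell_type": "GRU",
    "num_layers": 2, "dropout": 0.1, "learning_rate": 0.005,
    "optimizer": "SGD", "teacher_forcing_ratio": 0.5, "max_length": 50,
}

model = Translator("tam", params, device)
losses = model.train(iters=5000)     # train on the first 5000 shuffled pairs
print(model.evaluate("vanakkam"))    # transliterate a single word
print(model.test_validate("valid"))  # exact-match accuracy on a held-out split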
test_best_attention.py
ADDED
@@ -0,0 +1,380 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import random
+import numpy as np
+import matplotlib.pyplot as plt
+import pandas as pd
+import time
+
+import wandb
+wandb.login()
+
+random.seed()
+
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)
+
+# Language Model
+SOS_token = 0
+EOS_token = 1
+
+class Language:
+    def __init__(self, name):
+        self.name = name
+        self.word2index = {}
+        self.word2count = {}
+        self.index2word = {SOS_token: "<", EOS_token: ">"}
+        self.n_chars = 2  # Count SOS and EOS
+
+    def addWord(self, word):
+        for char in word:
+            self.addChar(char)
+
+    def addChar(self, char):
+        if char not in self.word2index:
+            self.word2index[char] = self.n_chars
+            self.word2count[char] = 1
+            self.index2word[self.n_chars] = char
+            self.n_chars += 1
+        else:
+            self.word2count[char] += 1
+
+def get_data(lang: str, type: str) -> list[list[str]]:
+    """
+    Returns: 'pairs': list of [input_word, target_word] pairs
+    """
+    path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
+    df = pd.read_csv(path, header=None)
+    pairs = df.values.tolist()
+    return pairs
+
+def get_languages(lang: str):
+    """
+    Returns
+    1. input_lang: input language - English
+    2. output_lang: output language - Given language
+    3. pairs: list of [input_word, target_word] pairs
+    """
+    input_lang = Language('eng')
+    output_lang = Language(lang)
+    pairs = get_data(lang, "train")
+    for pair in pairs:
+        input_lang.addWord(pair[0])
+        output_lang.addWord(pair[1])
+    return input_lang, output_lang, pairs
+
+def get_cell(cell_type: str):
+    if cell_type == "LSTM":
+        return nn.LSTM
+    elif cell_type == "GRU":
+        return nn.GRU
+    elif cell_type == "RNN":
+        return nn.RNN
+    else:
+        raise Exception("Invalid cell type")
+
+def get_optimizer(optimizer: str):
+    if optimizer == "SGD":
+        return optim.SGD
+    elif optimizer == "ADAM":
+        return optim.Adam
+    else:
+        raise Exception("Invalid optimizer")
+
+class Encoder(nn.Module):
+    def __init__(self,
+                 in_sz: int,
+                 embed_sz: int,
+                 hidden_sz: int,
+                 cell_type: str,
+                 n_layers: int,
+                 dropout: float):
+
+        super(Encoder, self).__init__()
+        self.hidden_sz = hidden_sz
+        self.n_layers = n_layers
+        self.dropout = dropout
+        self.cell_type = cell_type
+        self.embedding = nn.Embedding(in_sz, embed_sz)
+
+        self.rnn = get_cell(cell_type)(input_size = embed_sz,
+                                       hidden_size = hidden_sz,
+                                       num_layers = n_layers,
+                                       dropout = dropout)
+
+    def forward(self, input, hidden, cell):
+        embedded = self.embedding(input).view(1, 1, -1)
+
+        if(self.cell_type == "LSTM"):
+            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
+        else:
+            output, hidden = self.rnn(embedded, hidden)
+
+        return output, hidden, cell
+
+    def initHidden(self):
+        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
+
+class AttentionDecoder(nn.Module):
+    def __init__(self,
+                 out_sz: int,
+                 embed_sz: int,
+                 hidden_sz: int,
+                 cell_type: str,
+                 n_layers: int,
+                 dropout: float):
+
+        super(AttentionDecoder, self).__init__()
+        self.hidden_sz = hidden_sz
+        self.n_layers = n_layers
+        self.dropout = dropout
+        self.cell_type = cell_type
+        self.embedding = nn.Embedding(out_sz, embed_sz)
+
+        self.attn = nn.Linear(hidden_sz + embed_sz, 50)
+        self.attn_combine = nn.Linear(hidden_sz + embed_sz, hidden_sz)
+
+        self.rnn = get_cell(cell_type)(input_size = hidden_sz,
+                                       hidden_size = hidden_sz,
+                                       num_layers = n_layers,
+                                       dropout = dropout)
+
+        self.out = nn.Linear(hidden_sz, out_sz)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, input, hidden, cell, encoder_outputs):
+        embedding = self.embedding(input).view(1, 1, -1)
+
+        attn_weights = F.softmax(self.attn(torch.cat((embedding[0], hidden[0]), 1)), dim=1)
+        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
+
+        output = torch.cat((embedding[0], attn_applied[0]), 1)
+        output = self.attn_combine(output).unsqueeze(0)
+
+        if(self.cell_type == "LSTM"):
+            output, (hidden, cell) = self.rnn(output, (hidden, cell))
+        else:
+            output, hidden = self.rnn(output, hidden)
+
+        output = self.softmax(self.out(output[0]))
+        return output, hidden, cell, attn_weights
+
+    def initHidden(self):
+        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
+
+def indexesFromWord(lang: Language, word: str):
+    return [lang.word2index[char] for char in word]
+
+def tensorFromWord(lang: Language, word: str):
+    indexes = indexesFromWord(lang, word)
+    indexes.append(EOS_token)
+    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
+
+def tensorsFromPair(input_lang: Language, output_lang: Language, pair: list[str]):
+    input_tensor = tensorFromWord(input_lang, pair[0])
+    target_tensor = tensorFromWord(output_lang, pair[1])
+    return (input_tensor, target_tensor)
+
+def params_definition():
+    """
+    params:
+
+    embed_size : size of embedding (input and output) (8, 16, 32, 64)
+    hidden_size : size of hidden layer (64, 128, 256, 512)
+    cell_type : type of cell (LSTM, GRU, RNN)
+    num_layers : number of layers in encoder (1, 2, 3)
+    dropout : dropout probability
+    learning_rate : learning rate
+    teacher_forcing_ratio : teacher forcing ratio (0.5 fixed for now)
+    optimizer : optimizer (SGD, Adam)
+    max_length : maximum length of input word (50 fixed for now)
+
+    """
+    pass
+
+PRINT_EVERY = 5000
+PLOT_EVERY = 100
+
+class Translator:
+    def __init__(self, lang: str, params: dict):
+        self.lang = lang
+        self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
+        self.input_size = self.input_lang.n_chars
+        self.output_size = self.output_lang.n_chars
+
+        self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair) for pair in self.pairs]
+
+        self.encoder = Encoder(in_sz = self.input_size,
+                               embed_sz = params["embed_size"],
+                               hidden_sz = params["hidden_size"],
+                               cell_type = params["cell_type"],
+                               n_layers = params["num_layers"],
+                               dropout = params["dropout"]).to(device)
+
+        self.decoder = AttentionDecoder(out_sz = self.output_size,
+                                        embed_sz = params["embed_size"],
+                                        hidden_sz = params["hidden_size"],
+                                        cell_type = params["cell_type"],
+                                        n_layers = params["num_layers"],
+                                        dropout = params["dropout"]).to(device)
+
+        self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
+        self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
+
+        self.criterion = nn.NLLLoss()
+
+        self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
+        self.max_length = params["max_length"]
+
+    def train_single(self, input_tensor, target_tensor):
+        encoder_hidden = self.encoder.initHidden()
+        encoder_cell = self.encoder.initHidden()
+
+        self.encoder_optimizer.zero_grad()
+        self.decoder_optimizer.zero_grad()
+
+        input_length = input_tensor.size(0)
+        target_length = target_tensor.size(0)
+
+        encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
+
+        loss = 0
+
+        for ei in range(input_length):
+            encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+            encoder_outputs[ei] = encoder_output[0, 0]
+
+        decoder_input = torch.tensor([[SOS_token]], device=device)
+        decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+        use_teacher_forcing = True if random.random() < self.teacher_forcing_ratio else False
+
+        if use_teacher_forcing:
+            for di in range(target_length):
+                decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                loss += self.criterion(decoder_output, target_tensor[di])
+
+                decoder_input = target_tensor[di]
+        else:
+            for di in range(target_length):
+                decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                loss += self.criterion(decoder_output, target_tensor[di])
+
+                topv, topi = decoder_output.topk(1)
+                decoder_input = topi.squeeze().detach()
+                if decoder_input.item() == EOS_token:
+                    break
+
+        loss.backward()
+        self.encoder_optimizer.step()
+        self.decoder_optimizer.step()
+
+        return loss.item() / target_length
+
+    def train(self, iters=-1):
+        start_time = time.time()
+        plot_losses = []
+        print_loss_total = 0
+        plot_loss_total = 0
+
+        random.shuffle(self.training_pairs)
+        iters = len(self.training_pairs) if iters == -1 else iters
+
+        for iter in range(1, iters + 1):
+            training_pair = self.training_pairs[iter - 1]
+            input_tensor = training_pair[0]
+            target_tensor = training_pair[1]
+
+            loss = self.train_single(input_tensor, target_tensor)
+            print_loss_total += loss
+            plot_loss_total += loss
+
+            if iter % PRINT_EVERY == 0:
+                print_loss_avg = print_loss_total / PRINT_EVERY
+                print_loss_total = 0
+                current_time = time.time()
+                print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))
+
+            if iter % PLOT_EVERY == 0:
+                plot_loss_avg = plot_loss_total / PLOT_EVERY
+                plot_losses.append(plot_loss_avg)
+                plot_loss_total = 0
+
+        return plot_losses
+
+    def evaluate(self, word):
+        with torch.no_grad():
+            input_tensor = tensorFromWord(self.input_lang, word)
+            input_length = input_tensor.size()[0]
+            encoder_hidden = self.encoder.initHidden()
+            encoder_cell = self.encoder.initHidden()
+
+            encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
+
+            for ei in range(input_length):
+                encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
+                encoder_outputs[ei] += encoder_output[0, 0]
+
+            decoder_input = torch.tensor([[SOS_token]], device=device)
+            decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
+
+            decoded_chars = ""
+            decoder_attentions = torch.zeros(self.max_length, self.max_length)
+
+            for di in range(self.max_length):
+                decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
+                decoder_attentions[di] = decoder_attention.data
+                topv, topi = decoder_output.topk(1)
+
+                if topi.item() == EOS_token:
+                    break
+                else:
+                    decoded_chars += self.output_lang.index2word[topi.item()]
+
+                decoder_input = topi.squeeze().detach()
+
+            return decoded_chars, decoder_attentions[:di + 1]
+
+    def test_validate(self, type: str):
+        pairs = get_data(self.lang, type)
+        accuracy = 0
+        for pair in pairs:
+            output, _ = self.evaluate(pair[0])
+            if output == pair[1]:
+                accuracy += 1
+        return accuracy / len(pairs)
+
+params = {
+    "embed_size": 32,
+    "hidden_size": 256,
+    "cell_type": "RNN",
+    "num_layers": 2,
+    "dropout": .1,
+    "learning_rate": 0.001,
+    "optimizer": "SGD",
+    "teacher_forcing_ratio": 0.5,
+    "max_length": 50,
+    "weight_decay": 0.001
+}
+
+model = Translator('tam', params)
+
+model.encoder.load_state_dict(torch.load('./best_models_attn/encoder.pt'))
+model.decoder.load_state_dict(torch.load('./best_models_attn/decoder.pt'))
+
+with open("test_gen_attn.txt", "w") as f:
+    test_data = get_data("tam", "test")
+    f.write("Input, Target, Output\n")
+    accuracy = 0
+    for i in range(len(test_data)):
+        word, _ = model.evaluate(test_data[i][0])
+        f.write(test_data[i][0] + ", " + test_data[i][1] + ", " + word + "\n")
+        if test_data[i][1] == word:
+            accuracy += 1
+
+print("Test Accuracy: " + str(accuracy/len(test_data) * 100) + "%")
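
For reference, the attention in AttentionDecoder above scores a fixed window of max_length (= 50, the hard-coded output size of self.attn) encoder positions: concatenate the current input embedding with the first RNN layer's hidden state, project to 50 scores, softmax, then take a weighted sum of encoder_outputs via bmm. A shape-only sketch of that computation, with illustrative sizes matching the params dict above:

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_sz, hidden_sz, max_length = 32, 256, 50

attn = nn.Linear(hidden_sz + embed_sz, max_length)

embedding = torch.randn(1, 1, embed_sz)              # current decoder input, embedded
hidden = torch.randn(1, 1, hidden_sz)                # first RNN layer's hidden state
encoder_outputs = torch.randn(max_length, hidden_sz)

attn_weights = F.softmax(attn(torch.cat((embedding[0], hidden[0]), 1)), dim=1)  # (1, 50)
attn_applied = torch.bmm(attn_weights.unsqueeze(0),       # (1, 1, 50)
                         encoder_outputs.unsqueeze(0))    # (1, 50, 256)
print(attn_applied.shape)  # torch.Size([1, 1, 256]) -- the context vector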
test_best_vanilla.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from src.translator import Translator
import torch
import random
from src.helper import get_data

random.seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {
    "embed_size": 16,
    "hidden_size": 512,
    "cell_type": "LSTM",
    "num_layers": 2,
    "dropout": 0.1,
    "learning_rate": 0.005,
    "optimizer": "SGD",
    "teacher_forcing_ratio": 0.5,
    "max_length": 50
}

model = Translator("tam", params, device)

model.encoder.load_state_dict(torch.load("./best_model_vanilla/encoder.pt"))
model.decoder.load_state_dict(torch.load("./best_model_vanilla/decoder.pt"))

with open("test_gen.txt", "w") as f:
    test_data = get_data("tam", "test")
    f.write("Input, Target, Output\n")
    accuracy = 0
    for i in range(len(test_data)):
        # Decode once and reuse the result instead of calling evaluate twice.
        word = model.evaluate(test_data[i][0])
        f.write(test_data[i][0] + ", " + test_data[i][1] + ", " + word + "\n")
        if test_data[i][1] == word:
            accuracy += 1

print("Test Accuracy: " + str(accuracy / len(test_data) * 100) + "%")
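Exact string match gives no partial credit, so near-miss transliterations score zero. A sketch of a softer character-level metric using only the standard library (hypothetical helper, not part of this commit):

from difflib import SequenceMatcher

def char_similarity(target: str, prediction: str) -> float:
    # Ratio in [0, 1]; 1.0 means the strings match exactly.
    return SequenceMatcher(None, target, prediction).ratio()

# For example, averaged over the same test_data as above:
# sims = [char_similarity(pair[1], model.evaluate(pair[0])) for pair in test_data]
# print("Mean character similarity:", sum(sims) / len(sims))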
train.py
ADDED
@@ -0,0 +1,91 @@
from src.translator import Translator
import torch
import random
import argparse

random.seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {
    "embed_size": 16,
    "hidden_size": 512,
    "cell_type": "LSTM",
    "num_layers": 2,
    "dropout": 0.1,
    "learning_rate": 0.005,
    "optimizer": "SGD",
    "teacher_forcing_ratio": 0.5,
    "max_length": 50
}

language = "tam"

# Argument Parser
parser = argparse.ArgumentParser(description="Transliteration Model")
parser.add_argument("-es", "--embed_size", type=int, default=16, help="Embedding Size, good_choices = [8, 16, 32]")
parser.add_argument("-hs", "--hidden_size", type=int, default=512, help="Hidden Size, good_choices = [128, 256, 512]")
parser.add_argument("-ct", "--cell_type", type=str, default="LSTM", help="Cell Type, choices: [LSTM, GRU, RNN]")
parser.add_argument("-nl", "--num_layers", type=int, default=2, help="Number of Layers, choices: [1, 2, 3]")
parser.add_argument("-d", "--dropout", type=float, default=0.1, help="Dropout, good_choices: [0, 0.1, 0.2]")
parser.add_argument("-lr", "--learning_rate", type=float, default=0.005, help="Learning Rate, good_choices: [0.0005, 0.001, 0.005]")
parser.add_argument("-o", "--optimizer", type=str, default="SGD", help="Optimizer, choices: [SGD, ADAM]")
parser.add_argument("-l", "--language", type=str, default="tam", help="Language")
args = parser.parse_args()

params["embed_size"] = args.embed_size
params["hidden_size"] = args.hidden_size
params["cell_type"] = args.cell_type
params["num_layers"] = args.num_layers
params["dropout"] = args.dropout
params["learning_rate"] = args.learning_rate
params["optimizer"] = args.optimizer
language = args.language

model = Translator(language, params, device)

print("Training Model")
print("Language: {}".format(language))
print("Embedding Size: {}".format(params["embed_size"]))
print("Hidden Size: {}".format(params["hidden_size"]))
print("Cell Type: {}".format(params["cell_type"]))
print("Number of Layers: {}".format(params["num_layers"]))
print("Dropout: {}".format(params["dropout"]))
print("Learning Rate: {}".format(params["learning_rate"]))
print("Optimizer: {}".format(params["optimizer"]))
print("Teacher Forcing Ratio: {}".format(params["teacher_forcing_ratio"]))
print("Max Length: {}\n".format(params["max_length"]))

epochs = 10
old_validation_accuracy = 0

for epoch in range(epochs):
    print("Epoch: {}".format(epoch + 1))
    plot_losses = model.train()

    # take average of plot losses as training loss
    training_loss = sum(plot_losses) / len(plot_losses)

    print("Training Loss: {:.4f}".format(training_loss))

    training_accuracy = model.test_validate('train')
    print("Training Accuracy: {:.4f}".format(training_accuracy))

    validation_accuracy = model.test_validate('valid')
    print("Validation Accuracy: {:.4f}".format(validation_accuracy))

    if epoch > 0:
        if validation_accuracy < 0.0001:
            print("Validation Accuracy is too low. Stopping training.")
            break

        if validation_accuracy < 0.95 * old_validation_accuracy:
            print("Validation Accuracy is decreasing. Stopping training.")
            break

    old_validation_accuracy = validation_accuracy
print("Training Complete")

print("Testing Model")
test_accuracy = model.test_validate('test')
print("Test Accuracy: {:.4f}".format(test_accuracy))
print("Testing Complete")
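The early-stopping rule in the loop above, restated as a standalone predicate for readability (a sketch, not part of this commit): training halts once validation accuracy is near zero, or drops more than 5% relative to the previous epoch.

def should_stop(epoch: int, val_acc: float, prev_val_acc: float) -> bool:
    if epoch == 0:
        return False                        # always complete the first epoch
    if val_acc < 0.0001:
        return True                         # model is effectively not learning
    return val_acc < 0.95 * prev_val_acc    # >5% relative drop since last epoch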
train_attention.py
ADDED
@@ -0,0 +1,406 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import time
import argparse

random.seed()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Language Model
SOS_token = 0
EOS_token = 1

class Language:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {SOS_token: "<", EOS_token: ">"}
        self.n_chars = 2  # Count SOS and EOS

    def addWord(self, word):
        for char in word:
            self.addChar(char)

    def addChar(self, char):
        if char not in self.word2index:
            self.word2index[char] = self.n_chars
            self.word2count[char] = 1
            self.index2word[self.n_chars] = char
            self.n_chars += 1
        else:
            self.word2count[char] += 1
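A tiny illustration of the character vocabulary Language builds; indices 0 and 1 are reserved for SOS ("<") and EOS (">"):

lang = Language("eng")
lang.addWord("aba")
# lang.word2index -> {'a': 2, 'b': 3}
# lang.word2count -> {'a': 2, 'b': 1}
# lang.index2word -> {0: '<', 1: '>', 2: 'a', 3: 'b'}
# lang.n_chars    -> 4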
def get_data(lang: str, type: str) -> list[list[str]]:
    """
    Returns: 'pairs': list of [input_word, target_word] pairs
    """
    path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
    df = pd.read_csv(path, header=None)
    pairs = df.values.tolist()
    return pairs

def get_languages(lang: str):
    """
    Returns
    1. input_lang: input language - English
    2. output_lang: output language - Given language
    3. pairs: list of [input_word, target_word] pairs
    """
    input_lang = Language('eng')
    output_lang = Language(lang)
    pairs = get_data(lang, "train")
    for pair in pairs:
        input_lang.addWord(pair[0])
        output_lang.addWord(pair[1])
    return input_lang, output_lang, pairs

def get_cell(cell_type: str):
    if cell_type == "LSTM":
        return nn.LSTM
    elif cell_type == "GRU":
        return nn.GRU
    elif cell_type == "RNN":
        return nn.RNN
    else:
        raise Exception("Invalid cell type")

def get_optimizer(optimizer: str):
    if optimizer == "SGD":
        return optim.SGD
    elif optimizer == "ADAM":
        return optim.Adam
    else:
        raise Exception("Invalid optimizer")
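Both factory helpers return a class rather than an instance, so the same constructor arguments work for every choice. A quick usage sketch (assumes the torch imports at the top of this file):

rnn_cls = get_cell("GRU")             # -> nn.GRU
rnn = rnn_cls(input_size=32, hidden_size=256, num_layers=2, dropout=0.1)

opt_cls = get_optimizer("ADAM")       # -> optim.Adam
opt = opt_cls(rnn.parameters(), lr=0.001, weight_decay=0.001)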
class Encoder(nn.Module):
    def __init__(self,
                 in_sz: int,
                 embed_sz: int,
                 hidden_sz: int,
                 cell_type: str,
                 n_layers: int,
                 dropout: float):

        super(Encoder, self).__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.dropout = dropout
        self.cell_type = cell_type
        self.embedding = nn.Embedding(in_sz, embed_sz)

        self.rnn = get_cell(cell_type)(input_size=embed_sz,
                                       hidden_size=hidden_sz,
                                       num_layers=n_layers,
                                       dropout=dropout)

    def forward(self, input, hidden, cell):
        embedded = self.embedding(input).view(1, 1, -1)

        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            output, hidden = self.rnn(embedded, hidden)

        return output, hidden, cell

    def initHidden(self):
        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)

class AttentionDecoder(nn.Module):
    def __init__(self,
                 out_sz: int,
                 embed_sz: int,
                 hidden_sz: int,
                 cell_type: str,
                 n_layers: int,
                 dropout: float):

        super(AttentionDecoder, self).__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.dropout = dropout
        self.cell_type = cell_type
        self.embedding = nn.Embedding(out_sz, embed_sz)

        self.attn = nn.Linear(hidden_sz + embed_sz, 50)
        self.attn_combine = nn.Linear(hidden_sz + embed_sz, hidden_sz)

        self.rnn = get_cell(cell_type)(input_size=hidden_sz,
                                       hidden_size=hidden_sz,
                                       num_layers=n_layers,
                                       dropout=dropout)

        self.out = nn.Linear(hidden_sz, out_sz)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell, encoder_outputs):
        embedding = self.embedding(input).view(1, 1, -1)

        attn_weights = F.softmax(self.attn(torch.cat((embedding[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

        output = torch.cat((embedding[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(output, (hidden, cell))
        else:
            output, hidden = self.rnn(output, hidden)

        output = self.softmax(self.out(output[0]))
        return output, hidden, cell, attn_weights

    def initHidden(self):
        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
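For orientation, a shape walkthrough of AttentionDecoder.forward under the defaults used later in this script (embed_sz=32, hidden_sz=256, max_length=50); comments only, no new behavior:

# input                                    -> (1, 1), a single character index
# embedding(input).view(1, 1, -1)          -> (1, 1, 32)
# torch.cat((embedding[0], hidden[0]), 1)  -> (1, 32 + 256)
# F.softmax(self.attn(...), dim=1)         -> (1, 50): one weight per encoder slot
# torch.bmm((1, 1, 50), (1, 50, 256))      -> (1, 1, 256): the context vector
# cat with embedding, then attn_combine    -> (1, 1, 256): the RNN input
# self.out then log-softmax                -> (1, out_sz): log-probs over characters

Note the hard-coded 50 in nn.Linear(hidden_sz + embed_sz, 50): it must equal max_length, since each attention weight addresses one encoder output slot.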
163 |
+
|
164 |
+
def indexesFromWord(lang:Language, word:str):
|
165 |
+
return [lang.word2index[char] for char in word]
|
166 |
+
|
167 |
+
def tensorFromWord(lang:Language, word:str):
|
168 |
+
indexes = indexesFromWord(lang, word)
|
169 |
+
indexes.append(EOS_token)
|
170 |
+
return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
|
171 |
+
|
172 |
+
def tensorsFromPair(input_lang:Language, output_lang:Language, pair:list[str]):
|
173 |
+
input_tensor = tensorFromWord(input_lang, pair[0])
|
174 |
+
target_tensor = tensorFromWord(output_lang, pair[1])
|
175 |
+
return (input_tensor, target_tensor)
|
176 |
+
|
177 |
+
def params_definition():
|
178 |
+
"""
|
179 |
+
params:
|
180 |
+
|
181 |
+
embed_size : size of embedding (input and output) (8, 16, 32, 64)
|
182 |
+
hidden_size : size of hidden layer (64, 128, 256, 512)
|
183 |
+
cell_type : type of cell (LSTM, GRU, RNN)
|
184 |
+
num_layers : number of layers in encoder (1, 2, 3)
|
185 |
+
dropout : dropout probability
|
186 |
+
learning_rate : learning rate
|
187 |
+
teacher_forcing_ratio : teacher forcing ratio (0.5 fixed for now)
|
188 |
+
optimizer : optimizer (SGD, Adam)
|
189 |
+
max_length : maximum length of input word (50 fixed for now)
|
190 |
+
|
191 |
+
"""
|
192 |
+
pass
|
193 |
+
|
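The tensor helpers turn a word into a column of character indices with EOS appended. An illustration with hypothetical indices ('a' -> 2, 'b' -> 3):

# indexesFromWord(lang, "ab") -> [2, 3]
# tensorFromWord(lang, "ab")  -> tensor([[2], [3], [1]])   # shape (3, 1); EOS_token == 1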
PRINT_EVERY = 5000
PLOT_EVERY = 100

class Translator:
    def __init__(self, lang: str, params: dict):
        self.lang = lang
        self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
        self.input_size = self.input_lang.n_chars
        self.output_size = self.output_lang.n_chars

        self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair) for pair in self.pairs]

        self.encoder = Encoder(in_sz=self.input_size,
                               embed_sz=params["embed_size"],
                               hidden_sz=params["hidden_size"],
                               cell_type=params["cell_type"],
                               n_layers=params["num_layers"],
                               dropout=params["dropout"]).to(device)

        self.decoder = AttentionDecoder(out_sz=self.output_size,
                                        embed_sz=params["embed_size"],
                                        hidden_sz=params["hidden_size"],
                                        cell_type=params["cell_type"],
                                        n_layers=params["num_layers"],
                                        dropout=params["dropout"]).to(device)

        self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
        self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])

        self.criterion = nn.NLLLoss()

        self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
        self.max_length = params["max_length"]

    def train_single(self, input_tensor, target_tensor):
        encoder_hidden = self.encoder.initHidden()
        encoder_cell = self.encoder.initHidden()

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        input_length = input_tensor.size(0)
        target_length = target_tensor.size(0)

        encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)

        loss = 0

        for ei in range(input_length):
            encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden, decoder_cell = encoder_hidden, encoder_cell

        use_teacher_forcing = random.random() < self.teacher_forcing_ratio

        if use_teacher_forcing:
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
                loss += self.criterion(decoder_output, target_tensor[di])

                decoder_input = target_tensor[di]
        else:
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
                loss += self.criterion(decoder_output, target_tensor[di])

                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()
                if decoder_input.item() == EOS_token:
                    break

        loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        return loss.item() / target_length

    def train(self, iters=-1):
        start_time = time.time()
        plot_losses = []
        print_loss_total = 0
        plot_loss_total = 0

        random.shuffle(self.training_pairs)
        iters = len(self.training_pairs) if iters == -1 else iters

        # range end is iters + 1 so the final training pair is not skipped
        for iter in range(1, iters + 1):
            training_pair = self.training_pairs[iter - 1]
            input_tensor = training_pair[0]
            target_tensor = training_pair[1]

            loss = self.train_single(input_tensor, target_tensor)
            print_loss_total += loss
            plot_loss_total += loss

            if iter % PRINT_EVERY == 0:
                print_loss_avg = print_loss_total / PRINT_EVERY
                print_loss_total = 0
                current_time = time.time()
                print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))

            if iter % PLOT_EVERY == 0:
                plot_loss_avg = plot_loss_total / PLOT_EVERY
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

        return plot_losses
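train() returns one averaged loss per PLOT_EVERY (= 100) iterations, so its return value is ready to plot. A minimal sketch, assuming a Translator has already been constructed as `model`; matplotlib is imported above as plt:

losses = model.train()
plt.plot(losses)
plt.xlabel("iterations / {}".format(PLOT_EVERY))
plt.ylabel("average NLL loss")
plt.savefig("loss_curve.png")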
    def evaluate(self, word):
        with torch.no_grad():
            input_tensor = tensorFromWord(self.input_lang, word)
            input_length = input_tensor.size()[0]
            encoder_hidden = self.encoder.initHidden()
            encoder_cell = self.encoder.initHidden()

            encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)

            for ei in range(input_length):
                encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
                encoder_outputs[ei] += encoder_output[0, 0]

            decoder_input = torch.tensor([[SOS_token]], device=device)
            decoder_hidden, decoder_cell = encoder_hidden, encoder_cell

            decoded_chars = ""
            decoder_attentions = torch.zeros(self.max_length, self.max_length)

            for di in range(self.max_length):
                decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
                decoder_attentions[di] = decoder_attention.data
                topv, topi = decoder_output.topk(1)

                if topi.item() == EOS_token:
                    break
                else:
                    decoded_chars += self.output_lang.index2word[topi.item()]

                decoder_input = topi.squeeze().detach()

            return decoded_chars, decoder_attentions[:di + 1]

    def test_validate(self, type: str):
        pairs = get_data(self.lang, type)
        accuracy = 0
        for pair in pairs:
            output, _ = self.evaluate(pair[0])
            if output == pair[1]:
                accuracy += 1
        return accuracy / len(pairs)
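test_validate counts a word as correct only when the decoded string equals the target exactly; there is no partial credit. Schematically:

# output, _ = self.evaluate(pair[0])
# output == pair[1]                    -> +1
# any single-character difference      -> +0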
params = {
    "embed_size": 32,
    "hidden_size": 256,
    "cell_type": "RNN",
    "num_layers": 2,
    "dropout": 0,
    "learning_rate": 0.001,
    "optimizer": "SGD",
    "teacher_forcing_ratio": 0.5,
    "max_length": 50,
    "weight_decay": 0.001
}

language = "tam"

parser = argparse.ArgumentParser(description="Transliteration Model with Attention")
parser.add_argument('-es', '--embed_size', type=int, default=32, help='Embedding size')
parser.add_argument('-hs', '--hidden_size', type=int, default=256, help='Hidden size')
parser.add_argument('-ct', '--cell_type', type=str, default='RNN', help='Cell type')
parser.add_argument('-nl', '--num_layers', type=int, default=2, help='Number of layers')
parser.add_argument('-dr', '--dropout', type=float, default=0, help='Dropout')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, help='Learning rate')
parser.add_argument('-op', '--optimizer', type=str, default='SGD', help='Optimizer')
parser.add_argument('-wd', '--weight_decay', type=float, default=0.001, help='Weight decay')
parser.add_argument('-l', '--lang', type=str, default='tam', help='Language')

args = parser.parse_args()

for arg in vars(args):
    params[arg] = getattr(args, arg)

language = args.lang

print("Language: {}".format(language))
print("Embedding size: {}".format(params['embed_size']))
print("Hidden size: {}".format(params['hidden_size']))
print("Cell type: {}".format(params['cell_type']))
print("Number of layers: {}".format(params['num_layers']))
print("Dropout: {}".format(params['dropout']))
print("Learning rate: {}".format(params['learning_rate']))
print("Optimizer: {}".format(params['optimizer']))
print("Weight decay: {}".format(params['weight_decay']))
print("Teacher forcing ratio: {}".format(params['teacher_forcing_ratio']))
print("Max length: {}".format(params['max_length']))

model = Translator(language, params)

epochs = 10

for epoch in range(epochs):
    print("Epoch: {}".format(epoch + 1))
    model.train()

    train_accuracy = model.test_validate('train')
    print("Training Accuracy: {:.4f}".format(train_accuracy))

    validation_accuracy = model.test_validate('valid')
    print("Validation Accuracy: {:.4f}".format(validation_accuracy))

    test_accuracy = model.test_validate('test')
    print("Test Accuracy: {:.4f}".format(test_accuracy))
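Finally, a usage sketch for this script: the flags are the ones registered with argparse above, and the data path is the one hard-coded in get_data.

# Run from the repo root; expects ./aksharantar_sampled/<lang>/<lang>_{train,valid,test}.csv
#   python train_attention.py -es 32 -hs 256 -ct RNN -nl 2 -dr 0 -lr 0.001 -op SGD -wd 0.001 -l tam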