Lwasinam commited on
Commit
61e1114
·
1 Parent(s): 14885be

Upload 8 files

Browse files
Files changed (8) hide show
  1. app.py +123 -0
  2. config.py +25 -0
  3. dataset.py +101 -0
  4. model.py +558 -0
  5. requirements.txt +4 -0
  6. tokenizer_0.json +0 -0
  7. tokenizer_1.json +0 -0
  8. train.py +303 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import build_transformer
3
+ from train import greedy_decode, get_model, get_or_build_tokenizer
4
+
5
+ from config import get_config, get_weights_file_path
6
+ from tokenizers import Tokenizer
7
+ from pathlib import Path
8
+
9
+
10
+
11
+ config = get_config()
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ def process_text(config, src_text, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
14
+ seq_len = seq_len
15
+
16
+ # ds = ds
17
+ tokenizer_src = tokenizer_src
18
+ tokenizer_tgt = tokenizer_tgt
19
+ src_lang = src_lang
20
+ tgt_lang = tgt_lang
21
+
22
+ sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
23
+ eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
24
+ pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)
25
+ # Transform the text into tokens
26
+ enc_input_tokens = tokenizer_src.encode(src_text).ids
27
+ # dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
28
+
29
+ # Add sos, eos and padding to each sentence
30
+ enc_num_padding_tokens = seq_len - len(enc_input_tokens) - 2 # We will add <s> and </s>
31
+ # # We will only add <s>, and </s> only on the label
32
+ # dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
33
+
34
+ # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
35
+ if enc_num_padding_tokens < 0:
36
+ raise ValueError("Sentence is too long")
37
+
38
+ # Add <s> and </s> token
39
+ encoder_input = torch.cat(
40
+ [
41
+ sos_token,
42
+ torch.tensor(enc_input_tokens, dtype=torch.int64),
43
+ eos_token,
44
+ torch.tensor([pad_token] * enc_num_padding_tokens, dtype=torch.int64),
45
+ ],
46
+ dim=0,
47
+ )
48
+
49
+ # # Add only <s> token
50
+ # decoder_input = torch.cat(
51
+ # [
52
+ # self.sos_token,
53
+ # torch.tensor(dec_input_tokens, dtype=torch.int64),
54
+ # torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
55
+ # ],
56
+ # dim=0,
57
+ # )
58
+
59
+ # # Add only </s> token
60
+ # label = torch.cat(
61
+ # [
62
+ # torch.tensor(dec_input_tokens, dtype=torch.int64),
63
+ # self.eos_token,
64
+ # torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
65
+ # ],
66
+ # dim=0,
67
+ # )
68
+
69
+ # Double check the size of the tensors to make sure they are all seq_len long
70
+ assert encoder_input.size(0) == seq_len
71
+ # assert decoder_input.size(0) == seq_len
72
+ # assert label.size(0) == seq_len
73
+ return {
74
+ 'encoder_input': encoder_input,
75
+ # 'decoder_input': decoder_input,
76
+ "encoder_mask": (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
77
+ # "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
78
+ # "label": label, # (seq_len)
79
+
80
+ # "src_text": src_text,
81
+ # "tgt_text": tgt_text,
82
+ }
83
+
84
+ def causal_mask(size):
85
+ mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
86
+ return mask == 0
87
+
88
+ def infer(text, config):
89
+ tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
90
+ tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
91
+ model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())
92
+ state = torch.load('tmodel_36.pt', map_location=torch.device('cpu'))
93
+ model.load_state_dict(state['model_state_dict'])
94
+
95
+
96
+
97
+
98
+ model.eval()
99
+ with torch.no_grad():
100
+ processed_text = process_text(config, text, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
101
+ encoder_input = processed_text['encoder_input']
102
+ encoder_mask = processed_text['encoder_mask']
103
+
104
+ model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config['seq_len'], device)
105
+ model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
106
+ return model_out_text
107
+
108
+
109
+ import streamlit as st
110
+
111
+ st.title("English to Hausa Translation")
112
+
113
+ user_input = st.text_input("Enter your text:")
114
+ if user_input:
115
+ result = infer(user_input, config)
116
+ st.write("Inference Result:", result)
117
+
118
+
119
+
120
+
121
+
122
+
123
+
config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ def get_config():
4
+ return {
5
+ "batch_size":2,
6
+ "num_epochs": 100,
7
+ "lr": 10**-4,
8
+ "seq_len": 150,
9
+ "d_model": 512,
10
+ "lang_src": "0",
11
+ "lang_tgt": "1",
12
+ "model_folder": "weights",
13
+ "model_basename": "tmodel_",
14
+ "preload": None,
15
+ "tokenizer_file": "tokenizer_{0}.json",
16
+ "experiment_name": "runs/tmodel"
17
+ }
18
+
19
+ def get_weights_file_path(config, epoch: str):
20
+ model_folder = config["model_folder"]
21
+ model_basename = config["model_basename"]
22
+ model_filename = f"{model_basename}{epoch}.pt"
23
+ return str(Path('.') / model_folder / model_filename)
24
+
25
+
dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch.utils.data import Dataset, DataLoader
4
+ import pandas as pd
5
+
6
+ class BilingualDataset(Dataset):
7
+ def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
8
+ super().__init__()
9
+ self.seq_len = seq_len
10
+
11
+ self.ds = ds
12
+ self.tokenizer_src = tokenizer_src
13
+ self.tokenizer_tgt = tokenizer_tgt
14
+ self.src_lang = src_lang
15
+ self.tgt_lang = tgt_lang
16
+
17
+ self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
18
+ self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
19
+ self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)
20
+
21
+ def __len__(self):
22
+ return len(self.ds)
23
+
24
+ def __getitem__(self, idx):
25
+
26
+
27
+
28
+
29
+ src_target_pair = self.ds[idx]
30
+ src_text = src_target_pair[self.src_lang]
31
+ tgt_text = src_target_pair[self.tgt_lang]
32
+
33
+
34
+
35
+
36
+
37
+ # Transform the text into tokens
38
+ enc_input_tokens = self.tokenizer_src.encode(src_text).ids
39
+ dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
40
+
41
+ # Add sos, eos and padding to each sentence
42
+ enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # We will add <s> and </s>
43
+ # We will only add <s>, and </s> only on the label
44
+ dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
45
+
46
+ # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
47
+ if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
48
+ raise ValueError("Sentence is too long")
49
+
50
+ # Add <s> and </s> token
51
+ encoder_input = torch.cat(
52
+ [
53
+ self.sos_token,
54
+ torch.tensor(enc_input_tokens, dtype=torch.int64),
55
+ self.eos_token,
56
+ torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
57
+ ],
58
+ dim=0,
59
+ )
60
+
61
+ # Add only <s> token
62
+ decoder_input = torch.cat(
63
+ [
64
+ self.sos_token,
65
+ torch.tensor(dec_input_tokens, dtype=torch.int64),
66
+ torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
67
+ ],
68
+ dim=0,
69
+ )
70
+
71
+ # Add only </s> token
72
+ label = torch.cat(
73
+ [
74
+ torch.tensor(dec_input_tokens, dtype=torch.int64),
75
+ self.eos_token,
76
+ torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
77
+ ],
78
+ dim=0,
79
+ )
80
+
81
+ # Double check the size of the tensors to make sure they are all seq_len long
82
+ assert encoder_input.size(0) == self.seq_len
83
+ assert decoder_input.size(0) == self.seq_len
84
+ assert label.size(0) == self.seq_len
85
+ return {
86
+ 'encoder_input': encoder_input,
87
+ 'decoder_input': decoder_input,
88
+ "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
89
+ "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
90
+ "label": label, # (seq_len)
91
+
92
+ "src_text": src_text,
93
+ "tgt_text": tgt_text,
94
+ }
95
+
96
+
97
+
98
+ def causal_mask(size):
99
+ mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
100
+ return mask == 0
101
+
model.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ##Implementation of tranformer from scratch, this implememtation was inspired by Umar Jamir
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+ import torch.nn.functional as F
8
+
9
+ class InputEmbeddings(nn.Module):
10
+ def __init__(self, d_model: int, vocab_size: int) -> None:
11
+ super(InputEmbeddings, self).__init__()
12
+ self.d_model = d_model
13
+ self.embedding = nn.Embedding(vocab_size, d_model)
14
+
15
+
16
+ def forward(self, x):
17
+ # (batch, seq_len) --> (batch, seq_len, d_model)
18
+
19
+
20
+ # Multiply by sqrt(d_model) to scale the embeddings according to the paper
21
+ return self.embedding(x) * math.sqrt(self.d_model)
22
+
23
+
24
+ class PositionEncoding(nn.Module):
25
+ def __init__(self, seq_len: int, d_model:int, batch: int) -> None:
26
+ super(PositionEncoding, self).__init__()
27
+ # self.seq_len = seq_len
28
+ # self.d_model = d_model
29
+ # self.batch = batch
30
+ self.dropout = nn.Dropout(p=0.1)
31
+
32
+ ##initialize the positional encoding with zeros
33
+ positional_encoding = torch.zeros(seq_len, d_model)
34
+
35
+ ##first path of the equation is postion/scaling factor per dimesnsion
36
+ postion = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
37
+
38
+ ## this calculates the scaling term per dimension (512)
39
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
40
+
41
+ # div_term = torch.pow(10, torch.arange(0,self.d_model, 2).float() *-4/self.d_model)
42
+
43
+
44
+ ## this calculates the sin values for even indices
45
+ positional_encoding[:, 0::2] = torch.sin(postion * div_term)
46
+
47
+
48
+ ## this calculates the cos values for odd indices
49
+ positional_encoding[:, 1::2] = torch.cos(postion * div_term)
50
+
51
+ positional_encoding = positional_encoding.unsqueeze(0)
52
+ self.register_buffer('positional_encoding', positional_encoding)
53
+
54
+ def forward(self, x):
55
+ x = x + (self.positional_encoding[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
56
+ return self.dropout(x)
57
+
58
+
59
+
60
+ class MultiHeadAttention(nn.Module):
61
+ def __init__(self, d_model:int, heads: int) -> None:
62
+ super(MultiHeadAttention,self).__init__()
63
+ self.head = heads
64
+ self.head_dim = d_model // heads
65
+
66
+
67
+
68
+ assert d_model % heads == 0, 'cannot divide d_model by heads'
69
+
70
+ ## initialize the query, key and value weights 512*512
71
+ self.query_weight = nn.Linear(d_model, d_model, bias=False)
72
+ self.key_weight = nn.Linear(d_model, d_model,bias=False)
73
+ self.value_weight = nn.Linear(d_model, d_model,bias=False)
74
+ self.final_weight = nn.Linear(d_model, d_model, bias=False)
75
+ self.dropout = nn.Dropout(p=0.1)
76
+
77
+
78
+ def self_attention(self,query, key, value, mask,dropout):
79
+ #splitting query, key and value into heads
80
+ #this gives us a dimension of batch, num_heads, seq_len by 64. basically 1 sentence is converted to have 8 parts (heads)
81
+ query = query.view(query.shape[0], query.shape[1],self.head,self.head_dim).transpose(2,1)
82
+ key = key.view(key.shape[0], key.shape[1],self.head,self.head_dim).transpose(2,1)
83
+ value = value.view(value.shape[0], value.shape[1],self.head,self.head_dim).transpose(2,1)
84
+
85
+ attention = query @ key.transpose(3,2)
86
+ attention = attention / math.sqrt(query.shape[-1])
87
+
88
+ if mask is not None:
89
+ attention = attention.masked_fill(mask == 0, -1e9)
90
+ attention = torch.softmax(attention, dim=-1)
91
+ if dropout is not None:
92
+ attention = dropout(attention)
93
+ attention_scores = attention @ value
94
+
95
+ return attention_scores.transpose(2,1).contiguous().view(attention_scores.shape[0], -1, self.head_dim * self.head)
96
+
97
+ def forward(self,query, key, value,mask):
98
+
99
+ ## initialize the query, key and value matrices to give us seq_len by 512
100
+ query = self.query_weight(query)
101
+ key = self.key_weight(key)
102
+ value = self.value_weight(value)
103
+
104
+ attention = MultiHeadAttention.self_attention(self, query, key, value, mask, self.dropout)
105
+ return self.final_weight(attention)
106
+
107
+ class FeedForward(nn.Module):
108
+ def __init__(self,d_model:int, d_ff:int ) -> None:
109
+ super(FeedForward, self).__init__()
110
+
111
+ self.fc1 = nn.Linear(d_model, d_ff) # Fully connected layer 1
112
+ self.dropout = nn.Dropout(p=0.1) # Dropout layer
113
+ self.fc2 = nn.Linear(d_ff, d_model) # Fully connected layer 2
114
+
115
+
116
+ def forward(self,x ):
117
+ return self.fc2(self.dropout(torch.relu(self.fc1(x))))
118
+
119
+ class ProjectionLayer(nn.Module):
120
+ def __init__(self, d_model:int, vocab_size:int) :
121
+ super(ProjectionLayer, self).__init__()
122
+ self.fc = nn.Linear(d_model, vocab_size)
123
+ def forward(self, x):
124
+ x = self.fc(x)
125
+ return torch.log_softmax(x, dim=-1)
126
+
127
+ class EncoderBlock(nn.Module):
128
+ def __init__(self, d_model:int, head:int, d_ff:int) -> None:
129
+ super(EncoderBlock, self).__init__()
130
+ self.multiheadattention = MultiHeadAttention(d_model,head)
131
+ self.layer_norm1 = nn.LayerNorm(d_model)
132
+ self.dropout1 = nn.Dropout(p=0.1)
133
+ self.feedforward = FeedForward(d_model, d_ff)
134
+ self.layer_norm2 = nn.LayerNorm(d_model)
135
+ self.layer_norm3 = nn.LayerNorm(d_model)
136
+ self.dropout2 = nn.Dropout(p=0.1)
137
+
138
+ def forward(self, x, src_mask):
139
+ # Self-attention block
140
+ norm = self.layer_norm1(x)
141
+ attention = self.multiheadattention(norm, norm, norm, src_mask)
142
+ x = (x + self.dropout1(attention))
143
+
144
+ # Feedforward block
145
+ norm2 = self.layer_norm2(x)
146
+ ff = self.feedforward(x)
147
+ return x + self.dropout2(ff)
148
+
149
+ class Encoder(nn.Module):
150
+ def __init__(self, number_of_block:int, d_model:int, head:int, d_ff:int) -> None:
151
+ super(Encoder, self).__init__()
152
+ self.norm = nn.LayerNorm(d_model)
153
+
154
+ # Use nn.ModuleList to store the EncoderBlock instances
155
+ self.encoders = nn.ModuleList([EncoderBlock(d_model, head, d_ff)
156
+ for _ in range(number_of_block)])
157
+
158
+ def forward(self, x, src_mask):
159
+ for encoder_block in self.encoders:
160
+ x = encoder_block(x, src_mask)
161
+ return self.norm(x)
162
+
163
+ class DecoderBlock(nn.Module):
164
+ def __init__(self, d_model:int, head:int, d_ff:int) -> None:
165
+ super(DecoderBlock, self).__init__()
166
+ self.head_dim = d_model // head
167
+
168
+ self.multiheadattention = MultiHeadAttention(d_model, head)
169
+ self.crossattention = MultiHeadAttention(d_model, head)
170
+ self.layer_norm1 = nn.LayerNorm(d_model)
171
+ self.dropout1 = nn.Dropout(p=0.1)
172
+ self.feedforward = FeedForward(d_model,d_ff)
173
+ self.layer_norm2 = nn.LayerNorm(d_model)
174
+ self.layer_norm3 = nn.LayerNorm(d_model)
175
+ self.layer_norm4 = nn.LayerNorm(d_model)
176
+ self.dropout2 = nn.Dropout(p=0.1)
177
+ self.dropout3 = nn.Dropout(p=0.1)
178
+ def forward(self, x, src_mask, tgt_mask, encoder_output):
179
+ #Self-attention block
180
+ norm = self.layer_norm1(x)
181
+ attention = self.multiheadattention(norm, norm, norm, tgt_mask)
182
+ x = (x + self.dropout1(attention))
183
+
184
+ # Cross-attention block
185
+ norm2 = self.layer_norm2(x)
186
+ cross_attention = self.crossattention(norm, encoder_output, encoder_output, src_mask)
187
+ x = (x + self.dropout2(cross_attention))
188
+
189
+ # Feedforward block
190
+ norm3 = self.layer_norm3(x)
191
+ ff = self.feedforward(norm3)
192
+ return x + self.dropout3(ff)
193
+
194
+
195
+ class Decoder(nn.Module):
196
+ def __init__(self, number_of_block:int,d_model:int, head:int, d_ff:int) -> None:
197
+ super(Decoder, self).__init__()
198
+ self.norm = nn.LayerNorm(d_model)
199
+ self.decoders = nn.ModuleList([DecoderBlock(d_model, head, d_ff)
200
+ for _ in range(number_of_block)])
201
+
202
+ def forward(self, x, src_mask, tgt_mask, encoder_output):
203
+ for decoder_block in self.decoders:
204
+ x = decoder_block(x, src_mask, tgt_mask, encoder_output)
205
+ return self.norm(x)
206
+
207
+
208
+ class Transformer(nn.Module):
209
+ def __init__(self, seq_len:int, batch:int, d_model:int,target_vocab_size:int, source_vocab_size:int, head: int = 8, d_ff: int = 2048, number_of_block: int = 6) -> None:
210
+ super(Transformer, self).__init__()
211
+
212
+
213
+ self.encoder = Encoder(number_of_block,d_model, head, d_ff )
214
+ self.decoder = Decoder(number_of_block, d_model, head, d_ff )
215
+ # encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
216
+ # self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
217
+
218
+ # decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
219
+ # self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
220
+ self.projection = ProjectionLayer(d_model, target_vocab_size)
221
+ self.source_embedding = InputEmbeddings(d_model,source_vocab_size )
222
+ self.target_embedding = InputEmbeddings(d_model,target_vocab_size)
223
+ self.positional_encoding = PositionEncoding(seq_len, d_model, batch)
224
+
225
+
226
+ def encode(self,x, src_mask):
227
+ x = self.source_embedding(x)
228
+ x = self.positional_encoding(x)
229
+ return self.encoder(x, src_mask)
230
+
231
+ def decode(self,x, src_mask, tgt_mask, encoder_output):
232
+ x = self.target_embedding(x)
233
+ x = self.positional_encoding(x)
234
+ return self.decoder(x, src_mask, tgt_mask, encoder_output,)
235
+
236
+ def project(self, x):
237
+ return self.projection(x)
238
+
239
+
240
+
241
+ def build_transformer(seq_len, batch, target_vocab_size, source_vocab_size, d_model)-> Transformer:
242
+
243
+
244
+ transformer = Transformer(seq_len, batch, d_model, target_vocab_size, source_vocab_size )
245
+
246
+ #Initialize the parameters
247
+ for p in transformer.parameters():
248
+ if p.dim() > 1:
249
+ nn.init.xavier_uniform_(p)
250
+ return transformer
251
+
252
+
253
+ # import torch
254
+ # import torch.nn as nn
255
+ # import math
256
+
257
+ # class LayerNormalization(nn.Module):
258
+
259
+ # def __init__(self, eps:float=10**-6) -> None:
260
+ # super().__init__()
261
+ # self.eps = eps
262
+ # self.alpha = nn.Parameter(torch.ones(1)) # alpha is a learnable parameter
263
+ # self.bias = nn.Parameter(torch.zeros(1)) # bias is a learnable parameter
264
+
265
+ # def forward(self, x):
266
+ # # x: (batch, seq_len, hidden_size)
267
+ # # Keep the dimension for broadcasting
268
+ # mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1)
269
+ # # Keep the dimension for broadcasting
270
+ # std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1)
271
+ # # eps is to prevent dividing by zero or when std is very small
272
+ # return self.alpha * (x - mean) / (std + self.eps) + self.bias
273
+
274
+ # class FeedForwardBlock(nn.Module):
275
+
276
+ # def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
277
+ # super().__init__()
278
+ # self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
279
+ # self.dropout = nn.Dropout(dropout)
280
+ # self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2
281
+
282
+ # def forward(self, x):
283
+ # # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
284
+ # return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
285
+
286
+ # class InputEmbeddings(nn.Module):
287
+
288
+ # def __init__(self, d_model: int, vocab_size: int) -> None:
289
+ # super().__init__()
290
+ # self.d_model = d_model
291
+ # self.vocab_size = vocab_size
292
+ # self.embedding = nn.Embedding(vocab_size, d_model)
293
+
294
+ # def forward(self, x):
295
+ # # (batch, seq_len) --> (batch, seq_len, d_model)
296
+ # # Multiply by sqrt(d_model) to scale the embeddings according to the paper
297
+ # return self.embedding(x) * math.sqrt(self.d_model)
298
+
299
+ # class PositionalEncoding(nn.Module):
300
+
301
+ # def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
302
+ # super().__init__()
303
+ # self.d_model = d_model
304
+ # self.seq_len = seq_len
305
+ # self.dropout = nn.Dropout(dropout)
306
+ # # Create a matrix of shape (seq_len, d_model)
307
+ # pe = torch.zeros(seq_len, d_model)
308
+ # # Create a vector of shape (seq_len)
309
+ # position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
310
+ # # Create a vector of shape (d_model)
311
+ # div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
312
+ # # Apply sine to even indices
313
+ # pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model))
314
+ # # Apply cosine to odd indices
315
+ # pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model))
316
+ # # Add a batch dimension to the positional encoding
317
+ # pe = pe.unsqueeze(0) # (1, seq_len, d_model)
318
+ # # Register the positional encoding as a buffer
319
+ # pe = pe.transpose(1,2)
320
+ # self.register_buffer('pe', pe)
321
+
322
+ # def forward(self, x):
323
+ # x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
324
+ # return self.dropout(x)
325
+
326
+ # class ResidualConnection(nn.Module):
327
+
328
+ # def __init__(self, dropout: float) -> None:
329
+ # super().__init__()
330
+ # self.dropout = nn.Dropout(dropout)
331
+ # self.norm = LayerNormalization()
332
+
333
+ # def forward(self, x, sublayer):
334
+ # return x + self.dropout(sublayer(self.norm(x)))
335
+
336
+ # class MultiHeadAttentionBlock(nn.Module):
337
+
338
+ # def __init__(self, d_model: int, h: int, dropout: float) -> None:
339
+ # super().__init__()
340
+ # self.d_model = d_model # Embedding vector size
341
+ # self.h = h # Number of heads
342
+ # # Make sure d_model is divisible by h
343
+ # assert d_model % h == 0, "d_model is not divisible by h"
344
+
345
+ # self.d_k = d_model // h # Dimension of vector seen by each head
346
+ # self.w_q = nn.Linear(d_model, d_model) # Wq
347
+ # self.w_k = nn.Linear(d_model, d_model) # Wk
348
+ # self.w_v = nn.Linear(d_model, d_model) # Wv
349
+ # self.w_o = nn.Linear(d_model, d_model) # Wo
350
+ # self.dropout = nn.Dropout(dropout)
351
+
352
+ # @staticmethod
353
+ # def attention(query, key, value, mask, dropout: nn.Dropout):
354
+ # d_k = query.shape[-1]
355
+ # # Just apply the formula from the paper
356
+ # # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
357
+
358
+ # attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
359
+
360
+
361
+ # if mask is not None:
362
+ # # Write a very low value (indicating -inf) to the positions where mask == 0
363
+ # attention_scores.masked_fill_(mask == 0, -1e9)
364
+ # attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len) # Apply softmax
365
+ # if dropout is not None:
366
+ # attention_scores = dropout(attention_scores)
367
+ # # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
368
+ # # return attention scores which can be used for visualization
369
+ # return (attention_scores @ value), attention_scores
370
+
371
+ # def forward(self, q, k, v, mask):
372
+ # query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
373
+ # key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
374
+ # value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
375
+
376
+ # # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
377
+ # query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
378
+ # key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
379
+ # value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
380
+
381
+ # # Calculate attention
382
+ # x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
383
+
384
+
385
+ # # Combine all the heads together
386
+ # # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
387
+ # x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
388
+
389
+ # # Multiply by Wo
390
+ # # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
391
+ # return self.w_o(x)
392
+
393
+ # # class EncoderBlock(nn.Module):
394
+
395
+ # # def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
396
+ # # super().__init__()
397
+ # # self.self_attention_block = self_attention_block
398
+ # # self.feed_forward_block = feed_forward_block
399
+ # # self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
400
+
401
+ # # def forward(self, x, src_mask):
402
+ # # x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
403
+ # # x = self.residual_connections[1](x, self.feed_forward_block)
404
+ # # return x
405
+
406
+ # # class Encoder(nn.Module):
407
+
408
+ # # def __init__(self, layers: nn.ModuleList) -> None:
409
+ # # super().__init__()
410
+ # # self.layers = layers
411
+ # # self.norm = LayerNormalization()
412
+
413
+ # # def forward(self, x, mask):
414
+ # # for layer in self.layers:
415
+ # # x = layer(x, mask)
416
+ # # return self.norm(x)
417
+ # class EncoderBlock(nn.Module):
418
+ # def __init__(self, d_model:int, head:int, d_ff:int) -> None:
419
+ # super(EncoderBlock, self).__init__()
420
+ # self.multiheadattention = MultiHeadAttentionBlock(d_model,head, 0.1)
421
+ # self.layer_norm1 = nn.LayerNorm(d_model)
422
+ # self.dropout1 = nn.Dropout(p=0.1)
423
+ # self.feedforward = FeedForwardBlock(d_model, d_ff, 0.1)
424
+ # self.layer_norm2 = nn.LayerNorm(d_model)
425
+ # self.layer_norm3 = nn.LayerNorm(d_model)
426
+ # self.dropout2 = nn.Dropout(p=0.1)
427
+
428
+ # def forward(self, x, src_mask):
429
+ # # Self-attention block
430
+ # norm = self.layer_norm1(x)
431
+ # attention = self.multiheadattention(norm, norm, norm, src_mask)
432
+ # x = (x + self.dropout1(attention))
433
+
434
+ # # Feedforward block
435
+ # norm2 = self.layer_norm2(x)
436
+ # ff = self.feedforward(x)
437
+ # return x + self.dropout2(ff)
438
+
439
+ # class Encoder(nn.Module):
440
+ # def __init__(self, number_of_block:int, d_model:int, head:int, d_ff:int) -> None:
441
+ # super(Encoder, self).__init__()
442
+ # self.norm = nn.LayerNorm(d_model)
443
+
444
+ # # Use nn.ModuleList to store the EncoderBlock instances
445
+ # self.encoders = nn.ModuleList([EncoderBlock(d_model, head, d_ff)
446
+ # for _ in range(number_of_block)])
447
+
448
+ # def forward(self, x, src_mask):
449
+ # for encoder_block in self.encoders:
450
+ # x = encoder_block(x, src_mask)
451
+ # return self.norm(x)
452
+
453
+ # class ProjectionLayer(nn.Module):
454
+
455
+ # def __init__(self, d_model, vocab_size) -> None:
456
+ # super().__init__()
457
+ # self.proj = nn.Linear(d_model, vocab_size)
458
+
459
+ # def forward(self, x) -> None:
460
+ # # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
461
+ # return torch.log_softmax(self.proj(x), dim = -1)
462
+
463
+ # class DecoderBlock(nn.Module):
464
+ # def __init__(self, d_model:int, head:int, d_ff:int) -> None:
465
+ # super(DecoderBlock, self).__init__()
466
+ # self.head_dim = d_model // head
467
+
468
+ # self.multiheadattention = MultiHeadAttentionBlock(d_model, head, 0.1)
469
+ # self.crossattention = MultiHeadAttentionBlock(d_model, head, 0.1)
470
+ # self.layer_norm1 = nn.LayerNorm(d_model)
471
+ # self.dropout1 = nn.Dropout(p=0.1)
472
+ # self.feedforward = FeedForwardBlock(d_model,d_ff, 0.1)
473
+ # self.layer_norm2 = nn.LayerNorm(d_model)
474
+ # self.layer_norm3 = nn.LayerNorm(d_model)
475
+ # self.layer_norm4 = nn.LayerNorm(d_model)
476
+ # self.dropout2 = nn.Dropout(p=0.1)
477
+ # self.dropout3 = nn.Dropout(p=0.1)
478
+ # def forward(self, x, src_mask, tgt_mask, encoder_output):
479
+ # # Self-attention block
480
+ # norm = self.layer_norm1(x)
481
+ # attention = self.multiheadattention(norm, norm, norm, tgt_mask)
482
+ # x = (x + self.dropout1(attention))
483
+
484
+ # # Cross-attention block
485
+ # norm2 = self.layer_norm2(x)
486
+ # cross_attention = self.crossattention(norm, encoder_output, encoder_output, src_mask)
487
+ # x = (x + self.dropout2(cross_attention))
488
+
489
+ # # Feedforward block
490
+ # norm3 = self.layer_norm3(x)
491
+ # ff = self.feedforward(norm3)
492
+ # return x + self.dropout3(ff)
493
+
494
+
495
+ # class Decoder(nn.Module):
496
+ # def __init__(self, number_of_block:int,d_model:int, head:int, d_ff:int) -> None:
497
+ # super(Decoder, self).__init__()
498
+ # self.norm = nn.LayerNorm(d_model)
499
+ # self.decoders = nn.ModuleList([DecoderBlock(d_model, head, d_ff)
500
+ # for _ in range(number_of_block)])
501
+
502
+ # def forward(self, x, src_mask, tgt_mask, encoder_output):
503
+ # for decoder_block in self.decoders:
504
+ # x = decoder_block(x, src_mask, tgt_mask, encoder_output)
505
+ # return self.norm(x)
506
+
507
+
508
+
509
+ # class Transformer(nn.Module):
510
+ # def __init__(self, seq_len:int, batch:int, d_model:int,target_vocab_size:int, source_vocab_size:int, head: int = 8, d_ff: int = 2048, number_of_block: int = 6, dropout: float = 0.1) -> None:
511
+ # super(Transformer, self).__init__()
512
+
513
+
514
+ # self.encoder = Encoder(number_of_block,d_model, head, d_ff )
515
+ # self.decoder = Decoder(number_of_block, d_model, head, d_ff )
516
+
517
+
518
+ # # encoder_self_attention_block = MultiHeadAttentionBlock(d_model, head, dropout)
519
+ # # feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
520
+ # # self.encoder = Encoder(nn.ModuleList([EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout) for _ in range(number_of_block)]))
521
+
522
+
523
+ # # decoder_self_attention_block = MultiHeadAttentionBlock(d_model, head, dropout)
524
+ # # decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, head, dropout)
525
+ # # feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
526
+ # # self.decoder = Decoder(nn.ModuleList([DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout) for _ in range(number_of_block) ]))
527
+
528
+ # self.projection = ProjectionLayer(d_model, target_vocab_size)
529
+ # self.source_embedding = InputEmbeddings(d_model,source_vocab_size )
530
+ # self.target_embedding = InputEmbeddings(d_model,target_vocab_size)
531
+ # self.positional_encoding = PositionalEncoding(seq_len, d_model, dropout)
532
+
533
+
534
+ # def encode(self,x, src_mask):
535
+ # x = self.source_embedding(x)
536
+ # x = self.positional_encoding(x)
537
+ # return self.encoder(x, src_mask)
538
+
539
+ # def decode(self,encoder_output, src_mask, x, tgt_mask):
540
+ # x = self.target_embedding(x)
541
+ # x = self.positional_encoding(x)
542
+ # return self.decoder(x, src_mask, tgt_mask, encoder_output)
543
+
544
+ # def project(self, x):
545
+ # return self.projection(x)
546
+
547
+
548
+
549
+ # def build_transformer(seq_len, batch, target_vocab_size, source_vocab_size, d_model)-> Transformer:
550
+
551
+
552
+ # transformer = Transformer(seq_len, batch, d_model, target_vocab_size, source_vocab_size )
553
+
554
+ # #Initialize the parameters
555
+ # for p in transformer.parameters():
556
+ # if p.dim() > 1:
557
+ # nn.init.xavier_uniform_(p)
558
+ # return transformer
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlits
2
+ torch
3
+ tokenizers
4
+ numpy
tokenizer_0.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_1.json ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import build_transformer
2
+ from dataset import BilingualDataset, causal_mask
3
+ from config import get_config, get_weights_file_path
4
+
5
+ import torchtext.datasets as datasets
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset, DataLoader, random_split
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+
11
+ import warnings
12
+ from tqdm import tqdm
13
+ import os
14
+ from pathlib import Path
15
+
16
+ # Huggingface datasets and tokenizers
17
+ from datasets import load_dataset
18
+ from tokenizers import Tokenizer
19
+ from tokenizers.models import WordLevel
20
+ from tokenizers.trainers import WordLevelTrainer
21
+ from tokenizers.pre_tokenizers import Whitespace
22
+
23
+ import torchmetrics
24
+ from torch.utils.tensorboard import SummaryWriter
25
+
26
+ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
27
+ sos_idx = tokenizer_tgt.token_to_id("[SOS]")
28
+ eos_idx = tokenizer_tgt.token_to_id("[EOS]")
29
+
30
+ # Precompute the encoder output and reuse it for every step
31
+ encoder_output = model.encode(source, source_mask)
32
+ # Initialize the decoder input with the sos token
33
+ decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
34
+ while True:
35
+ if decoder_input.size(1) == max_len:
36
+ break
37
+ # build mask for target
38
+ decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
39
+
40
+
41
+ # calculate output
42
+ out =model.decode(decoder_input,source_mask, decoder_mask, encoder_output)
43
+
44
+
45
+ # get next token
46
+ prob = model.project(out[:, -1])
47
+ _, next_word = torch.max(prob, dim=1)
48
+
49
+ decoder_input = torch.cat(
50
+ [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
51
+ )
52
+
53
+ if next_word == eos_idx:
54
+ break
55
+
56
+ return decoder_input.squeeze(0)
57
+
58
+
59
+ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step,num_examples=2):
60
+ model.eval()
61
+ count = 0
62
+
63
+ source_texts = []
64
+ expected = []
65
+ predicted = []
66
+
67
+ try:
68
+ # get the console window width
69
+ with os.popen('stty size', 'r') as console:
70
+ _, console_width = console.read().split()
71
+ console_width = int(console_width)
72
+ except:
73
+ # If we can't get the console width, use 80 as default
74
+ console_width = 80
75
+
76
+ with torch.no_grad():
77
+ for batch in validation_ds:
78
+ count += 1
79
+ encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
80
+ encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)
81
+
82
+ # check that the batch size is 1
83
+ assert encoder_input.size(
84
+ 0) == 1, "Batch size must be 1 for validation"
85
+
86
+ model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
87
+
88
+ source_text = batch["src_text"][0]
89
+ target_text = batch["tgt_text"][0]
90
+ model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
91
+
92
+ source_texts.append(source_text)
93
+ expected.append(target_text)
94
+ predicted.append(model_out_text)
95
+
96
+ # Print the source, target and model output
97
+ print_msg('-'*console_width)
98
+ print_msg(f"{f'SOURCE: ':>12}{source_text}")
99
+ print_msg(f"{f'TARGET: ':>12}{target_text}")
100
+ print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")
101
+
102
+ if count == num_examples:
103
+ print_msg('-'*console_width)
104
+ break
105
+
106
+ # if writer:
107
+ # # Evaluate the character error rate
108
+ # # Compute the char error rate
109
+ # metric = torchmetrics.CharErrorRate()
110
+ # cer = metric(predicted, expected)
111
+ # writer.add_scalar('validation cer', cer, global_step)
112
+ # writer.flush()
113
+
114
+ # # Compute the word error rate
115
+ # metric = torchmetrics.WordErrorRate()
116
+ # wer = metric(predicted, expected)
117
+ # writer.add_scalar('validation wer', wer, global_step)
118
+ # writer.flush()
119
+
120
+ # # Compute the BLEU metric
121
+ # metric = torchmetrics.BLEUScore()
122
+ # bleu = metric(predicted, expected)
123
+ # writer.add_scalar('validation BLEU', bleu, global_step)
124
+ # writer.flush()
125
+
126
+ def get_all_sentences(ds, lang):
127
+ for item in ds:
128
+ yield item[lang]
129
+
130
+ def get_or_build_tokenizer(config, ds, lang):
131
+ tokenizer_path = Path(config['tokenizer_file'].format(lang))
132
+ if not Path.exists(tokenizer_path):
133
+ # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
134
+ tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
135
+ tokenizer.pre_tokenizer = Whitespace()
136
+ trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
137
+ tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
138
+ tokenizer.save(str(tokenizer_path))
139
+ else:
140
+ tokenizer = Tokenizer.from_file(str(tokenizer_path))
141
+ return tokenizer
142
+
143
+ def get_ds(config):
144
+ # It only has the train split, so we divide it overselves
145
+ ds_raw = load_dataset('Lwasinam/en-ha',
146
+ # f"{config['lang_src']}-{config['lang_tgt']}",
147
+ split='train')
148
+ print(ds_raw[0])
149
+
150
+ # Build tokenizers
151
+ tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
152
+ tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
153
+ seed = 42 # You can choose any integer as your seed
154
+ torch.manual_seed(seed)
155
+ # Keep 90% for training, 10% for validation
156
+ train_ds_size = int(0.9 * len(ds_raw))
157
+ val_ds_size = len(ds_raw) - train_ds_size
158
+ train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
159
+
160
+ train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
161
+ val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
162
+
163
+ # Find the maximum length of each sentence in the source and target sentence
164
+ max_len_src = 0
165
+ max_len_tgt = 0
166
+
167
+ for item in ds_raw:
168
+ src_ids = tokenizer_src.encode(item[config['lang_src']]).ids
169
+ tgt_ids = tokenizer_tgt.encode(item[config['lang_tgt']]).ids
170
+ max_len_src = max(max_len_src, len(src_ids))
171
+ max_len_tgt = max(max_len_tgt, len(tgt_ids))
172
+
173
+ print(f'Max length of source sentence: {max_len_src}')
174
+ print(f'Max length of target sentence: {max_len_tgt}')
175
+
176
+
177
+ train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
178
+ val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
179
+
180
+ return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
181
+
182
+ def get_model(config, vocab_src_len, vocab_tgt_len):
183
+ model = build_transformer( config['seq_len'],config['batch_size'], vocab_tgt_len,vocab_src_len, config['d_model'] )
184
+ return model
185
+
186
+ def train_model(config):
187
+ # Define the device
188
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
189
+ print("Using device:", device)
190
+
191
+ # Make sure the weights folder exists
192
+ Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
193
+
194
+ train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
195
+ model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
196
+ # Tensorboard
197
+ writer = SummaryWriter(config['experiment_name'])
198
+
199
+ optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
200
+
201
+ # If the user specified a model to preload before training, load it
202
+ initial_epoch = 0
203
+ global_step = 0
204
+ if config['preload']:
205
+ model_filename = get_weights_file_path(config, config['preload'])
206
+ print(f'Preloading model {model_filename}')
207
+ state = torch.load(model_filename)
208
+ model.load_state_dict(state['model_state_dict'])
209
+ initial_epoch = state['epoch'] + 1
210
+ optimizer.load_state_dict(state['optimizer_state_dict'])
211
+ global_step = state['global_step']
212
+
213
+ loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1).to(device)
214
+
215
+
216
+ for epoch in range(initial_epoch, config['num_epochs']):
217
+ model.train()
218
+ batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
219
+ for batch in batch_iterator:
220
+ optimizer.zero_grad()
221
+
222
+ encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
223
+ decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
224
+ encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
225
+ decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
226
+
227
+ # Run the tensors through the encoder, decoder and the projection layer
228
+
229
+ encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
230
+ decoder_output = model.decode( decoder_input,encoder_mask, decoder_mask, encoder_output) # (B, seq_len, d_model)
231
+ proj_output = model.project(decoder_output)
232
+
233
+ # (B, seq_len, vocab_size)
234
+
235
+ # Compare the output with the label
236
+ label = batch['label'].to(device) # (B, seq_len)
237
+
238
+ # Compute the loss using a simple cross entropy
239
+
240
+ loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
241
+ batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
242
+
243
+ # Log the loss
244
+ writer.add_scalar('train loss', loss.item(), global_step)
245
+ writer.flush()
246
+
247
+ # Backpropagate the loss
248
+ loss.backward()
249
+
250
+ # Update the weights
251
+ optimizer.step()
252
+
253
+
254
+ global_step += 1
255
+ model.eval()
256
+ eval_loss = 0.0
257
+ # batch_iterator = tqdm(v_dataloader, desc=f"Processing Epoch {epoch:02d}")
258
+ with torch.no_grad():
259
+ for batch in val_dataloader:
260
+
261
+
262
+ encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
263
+ decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
264
+ encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
265
+ decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
266
+
267
+ # Run the tensors through the encoder, decoder and the projection layer
268
+
269
+ encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
270
+ decoder_output = model.decode( decoder_input,encoder_mask, decoder_mask, encoder_output) # (B, seq_len, d_model)
271
+ proj_output = model.project(decoder_output)
272
+
273
+ # (B, seq_len, vocab_size)
274
+
275
+ # Compare the output with the label
276
+ label = batch['label'].to(device) # (B, seq_len)
277
+
278
+ # Compute the loss using a simple cross entropy
279
+
280
+ eval_loss += loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
281
+
282
+
283
+ avg_val_loss = eval_loss / len(val_dataloader)
284
+ print(f'Epoch {epoch},Validation Loss: {avg_val_loss.item()}')
285
+
286
+
287
+ # Run validation at the end of every epoch
288
+ run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
289
+
290
+ # Save the model at the end of every epoch
291
+ model_filename = get_weights_file_path(config, f"{epoch:02d}")
292
+ torch.save({
293
+ 'epoch': epoch,
294
+ 'model_state_dict': model.state_dict(),
295
+ 'optimizer_state_dict': optimizer.state_dict(),
296
+ 'global_step': global_step
297
+ }, model_filename)
298
+
299
+
300
+ if __name__ == '__main__':
301
+ warnings.filterwarnings("ignore")
302
+ config = get_config()
303
+ train_model(config)