File size: 11,003 Bytes
9065f39
 
 
 
4d8256e
9065f39
 
 
 
 
 
 
4d8256e
9065f39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9607e93
9065f39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6aa14e
 
2ce4922
 
 
a4cfc97
 
3621cfc
 
 
2ce4922
4020791
9607e93
2ce4922
 
 
 
d6aa14e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
import math
# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to token embeddings, then apply dropout.

    Args:
        emb_size: Width of the embedding vectors.
        dropout: Dropout probability applied to the summed embeddings.
        maxlen: Number of positions for which encodings are precomputed.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()

        # One frequency per (sin, cos) channel pair, decaying geometrically.
        frequencies = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = positions * frequencies

        # Even channels carry the sine, odd channels the cosine.
        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer of shape (maxlen, 1, emb_size): it is saved in
        # the state dict and moves with the module, but is not a parameter.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + positional rows)``.

        The rows are selected with ``token_embedding.size(0)`` and broadcast
        over the second dimension.
        """
        # NOTE(review): the slice is taken along dim 0. The transformer in this
        # file is built with batch_first=True, so dim 0 looks like the batch
        # axis rather than the sequence axis - confirm this is intended before
        # changing it (the buffer shape is part of saved checkpoints).
        rows = token_embedding.size(0)
        encoded = token_embedding + self.pos_embedding[:rows]
        return self.dropout(encoded)

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    """Look up token embeddings and scale them by ``sqrt(emb_size)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Width of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Return the scaled embeddings for ``tokens`` (cast to ``long`` first)."""
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder translation model built on ``torch.nn.Transformer``.

    Source and target tokens are embedded, combined with the positional
    encoding, run through a batch-first transformer, and projected to
    target-vocabulary logits by ``generator``.

    Args:
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        emb_size: Model width (``d_model``) and embedding size.
        nhead: Number of attention heads.
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary.
        dim_feedforward: Hidden size of the feed-forward sublayers.
        dropout: Dropout probability shared by the transformer and the
            positional encoding.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        # Maps decoder states to one logit per target-vocabulary entry.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return target-vocabulary logits."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Embed ``src`` and return the encoder output (the decoder's memory)."""
        embedded = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(embedded, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Embed ``tgt`` and return decoder states given the encoder ``memory``."""
        embedded = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(embedded, memory, tgt_mask)

import yaml 
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing


def addPreprocessing(tokenizer):
    """Install a post-processor that wraps every single sequence in BOS/EOS.

    The processor is set on ``tokenizer._tokenizer`` (the backend tokenizer),
    so the tokenizer is modified in place and nothing is returned.
    """
    bos, eos = tokenizer.bos_token, tokenizer.eos_token
    # Template reads "<bos> $A <eos>", where $A stands for the input sequence.
    template = " ".join((bos, "$A", eos))
    token_ids = [
        (eos, tokenizer.eos_token_id),
        (bos, tokenizer.bos_token_id),
    ]
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=template,
        special_tokens=token_ids,
    )

def load_model(model_checkpoint_dir='model.pt',config_dir='config.yaml'):
    """Build a ``Seq2SeqTransformer`` from a YAML config and load its weights.

    Args:
        model_checkpoint_dir: Path of the checkpoint file holding the state dict.
        config_dir: Path of the YAML file holding the model hyper-parameters.

    Returns:
        The model with the checkpoint's state dict loaded.
    """
    with open(config_dir, 'r') as config_file:
        params = yaml.safe_load(config_file)

    # Hyper-parameters are read by key and passed in constructor order.
    model = Seq2SeqTransformer(
        num_encoder_layers=params["num_encoder_layers"],
        num_decoder_layers=params["num_decoder_layers"],
        emb_size=params["emb_size"],
        nhead=params["nhead"],
        src_vocab_size=params["source_vocab_size"],
        tgt_vocab_size=params["target_vocab_size"],
        dim_feedforward=params["ffn_hid_dim"],
    )

    # Without CUDA the tensors are remapped to the CPU; otherwise they are
    # loaded onto whatever device they were saved from (the default).
    # NOTE(review): torch.load may unpickle arbitrary objects - only load
    # checkpoints from a trusted source.
    target = None if torch.cuda.is_available() else torch.device('cpu')
    state = torch.load(model_checkpoint_dir, map_location=target)
    model.load_state_dict(state)

    return model


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate a target sequence by taking the highest-scoring token at each step.

    Decoding runs under ``torch.no_grad()`` so no autograd graph is kept while
    the sequence grows.

    Args:
        model: Module exposing ``encode``, ``decode`` and ``generator``.
        src: Source token ids; moved to the module-level ``device``.
        src_mask: Source attention mask passed to ``model.encode``.
        max_len: Maximum length of the returned sequence, start symbol included.
        start_symbol: Token id used to seed the target sequence.

    Returns:
        A ``(1, length)`` long tensor starting with ``start_symbol``. It ends
        with ``target_tokenizer.eos_token_id`` when that token was produced
        before ``max_len`` was reached.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)
    eos_id = target_tokenizer.eos_token_id
    max_len = int(max_len)

    with torch.no_grad():
        # The encoder runs once; its output already lives on `device`.
        memory = model.encode(src, src_mask)

        ys = torch.tensor([[start_symbol]], dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            size = ys.size(1)
            # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
            tgt_mask = torch.triu(
                torch.full((size, size), float('-inf'), device=device),
                diagonal=1,
            )
            out = model.decode(ys, memory, tgt_mask)
            # Only the last position predicts the next token.
            logits = model.generator(out[:, -1])
            next_word = int(logits.argmax(dim=1).item())

            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
            ys = torch.cat([ys, next_token], dim=1)

            if next_word == eos_id:
                break

    return ys


def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,length_penalty):
    """Generate a target sequence with beam search.

    Hypotheses accumulate log-probabilities (log-softmax of the generator
    output). Finished and unfinished hypotheses are compared with a
    length-normalised score, ``log_prob / ((5 + length) / 6) ** length_penalty``.
    Hypotheses that emit EOS are kept across steps instead of being discarded.

    Args:
        model: Module exposing ``encode``, ``decode`` and ``generator``.
        src: Source token ids; moved to the module-level ``device``.
        src_mask: Source attention mask passed to ``model.encode``.
        max_len: Maximum length of a hypothesis, start symbol included.
        start_symbol: Token id used to seed every hypothesis.
        beam_size: Number of hypotheses kept per step (at least 1).
        length_penalty: Exponent of the length normalisation; 0 disables it.

    Returns:
        The best hypothesis as a ``(1, length)`` long tensor. If ``max_len``
        leaves no room for decoding, the tensor holds only ``start_symbol``.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)
    eos_id = target_tokenizer.eos_token_id
    max_len = int(max_len)
    beam_size = max(1, int(beam_size))

    def _normalised(log_prob, length):
        # Length normalisation so longer hypotheses are not unfairly penalised.
        return log_prob / (((5 + length) / 6) ** length_penalty)

    with torch.no_grad():
        # The encoder runs once and is shared by all hypotheses.
        memory = model.encode(src, src_mask)

        start = torch.tensor([[start_symbol]], dtype=torch.long, device=device)
        beams = [(start, 0.0)]   # (sequence, cumulative log-probability)
        finished = []            # (sequence, length-normalised score)

        for _ in range(max_len - 1):
            candidates = []

            for ys, score in beams:
                size = ys.size(1)
                # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
                tgt_mask = torch.triu(
                    torch.full((size, size), float('-inf'), device=device),
                    diagonal=1,
                )
                out = model.decode(ys, memory, tgt_mask)
                # Log-probabilities make the scores of successive steps additive.
                log_probs = torch.log_softmax(model.generator(out[:, -1]), dim=-1)

                k = min(beam_size, log_probs.size(-1))
                top_scores, top_indices = torch.topk(log_probs, k, dim=-1)

                for token_score, token in zip(top_scores[0].tolist(), top_indices[0].tolist()):
                    next_token = torch.full((1, 1), token, dtype=ys.dtype, device=ys.device)
                    new_ys = torch.cat([ys, next_token], dim=1)
                    new_score = score + token_score

                    if token == eos_id:
                        finished.append((new_ys, _normalised(new_score, new_ys.size(1))))
                    else:
                        candidates.append((new_ys, new_score))

            # Candidates of one step share the same length, so the raw
            # cumulative score ranks them the same as the normalised one.
            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = candidates[:beam_size]

            if not beams:
                # Every surviving hypothesis has emitted EOS.
                break

    hypotheses = finished + [
        (ys, _normalised(score, ys.size(1))) for ys, score in beams
    ]
    best_beam, _ = max(hypotheses, key=lambda item: item[1])
    return best_beam

def translate(src_sentence: str, strategy: str = 'greedy', lenght_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
    """Translate ``src_sentence`` with the module-level model and tokenizers.

    UI widgets may deliver numbers as floats, or ``None`` when a field is left
    empty, so the numeric arguments are coerced and missing values fall back
    to the signature defaults.

    Args:
        src_sentence: Sentence to translate.
        strategy: ``'greedy'`` or ``'beam search'``; ``None``/empty means greedy.
        lenght_extend: How many tokens longer than the source the output may be.
        beam_size: Number of beams (beam search only); clamped to at least 1.
        length_penalty: Length-normalisation exponent (beam search only).

    Returns:
        The decoded target string with special tokens removed.
    """
    # Normalise values coming from form widgets before validating them.
    strategy = strategy or 'greedy'
    lenght_extend = 5 if lenght_extend is None else int(lenght_extend)
    beam_size = 5 if beam_size is None else max(1, int(beam_size))
    length_penalty = 0.6 if length_penalty is None else float(length_penalty)

    assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'

    # Tokenize the source sentence
    src = source_tokenizer(src_sentence, **token_config)['input_ids']
    num_tokens = src.shape[1]
    # All-False boolean mask: no source position is hidden from attention.
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
    max_len = num_tokens + lenght_extend
    bos_id = target_tokenizer.bos_token_id

    if strategy == 'greedy':
        tgt_tokens = greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=bos_id).flatten()
    else:
        # Generate the target tokens using beam search decoding
        tgt_tokens = beam_search_decode(model, src, src_mask, max_len=max_len, start_symbol=bos_id, beam_size=beam_size, length_penalty=length_penalty).flatten()

    # Decode the target tokens and clean up the result
    return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)

# Special tokens registered on the source tokenizer when it is loaded.
special_tokens = {'unk_token':"[UNK]",
                  'cls_token':"[CLS]",
                  'eos_token': '[EOS]',
                  'sep_token':"[SEP]",
                  'pad_token':"[PAD]",
                  'mask_token':"[MASK]",
                  'bos_token':"[BOS]"}

source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')

# Wrap every encoded sequence in BOS ... EOS for both languages.
addPreprocessing(source_tokenizer)
addPreprocessing(target_tokenizer)

# Keyword arguments used for every tokenizer call in translate().
# `return_tensors` must name a framework ("pt" = PyTorch); translate() reads
# `input_ids.shape[1]`, which needs a batched tensor rather than a list.
token_config = {
                "add_special_tokens": True,
                "return_tensors": "pt",
             }

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the trained weights once at import time and switch to inference mode.
model = load_model()
model.to(device)
model.eval()

import gradio as gr

# Web form around translate(); the inputs are passed positionally in this order.
# Every widget has an explicit default so an untouched field never sends None,
# and the integer fields use precision=0 so they are delivered as ints.
iface = gr.Interface(
    fn=translate,
    inputs=[
        # The first positional argument of Textbox is its initial value, so the
        # prompt text is given as a placeholder instead of pre-filled input.
        gr.Textbox(label="Source sentence", placeholder="Enter a sentence to translate"),
        gr.Radio(['greedy', 'beam search'], value='greedy', label="Decoding Strategy"),
        gr.Number(value=5, precision=0, label="Length Extend (for greedy)"),
        gr.Number(value=5, precision=0, label="Beam Size (for beam search)"),
        gr.Number(value=0.6, label="Length Penalty (for beam search)")
    ],
    outputs=gr.Textbox(label="Translation"),
    title="Translation Interface for English to Kabyle",
    description="Translate text using a pre-trained model.",
)

# Launch the Gradio interface
iface.launch()