import yaml
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing

from .model import Seq2SeqTransformer


def addPreprocessing(tokenizer):
    # Wrap every encoded sequence with the tokenizer's BOS and EOS tokens
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
        special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id),
                        (tokenizer.bos_token, tokenizer.bos_token_id)])

def load_model(model_checkpoint_dir='model.pt', config_dir='config.yaml'):

    with open(config_dir, 'r') as yaml_file:
        loaded_model_params = yaml.safe_load(yaml_file)

    # Create a new instance of the model with the loaded configuration
    model = Seq2SeqTransformer(
        loaded_model_params["num_encoder_layers"],
        loaded_model_params["num_decoder_layers"],
        loaded_model_params["emb_size"],
        loaded_model_params["nhead"],
        loaded_model_params["source_vocab_size"],
        loaded_model_params["target_vocab_size"],
        loaded_model_params["ffn_hid_dim"]
    )

    # Load the checkpoint, falling back to the CPU when no GPU is available
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(model_checkpoint_dir, map_location=map_location)
    model.load_state_dict(checkpoint)

    return model
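
# The config.yaml read by load_model holds the architecture hyperparameters used
# above; the values below are only illustrative placeholders, not the shipped
# configuration:
#
#   num_encoder_layers: 3
#   num_decoder_layers: 3
#   emb_size: 512
#   nhead: 8
#   source_vocab_size: 30522
#   target_vocab_size: 30000
#   ffn_hid_dim: 512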


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # Move inputs to the device
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Encode the source sequence
    memory = model.encode(src, src_mask)

    # Initialize the target sequence with the start symbol
    ys = torch.tensor([[start_symbol]]).type(torch.long).to(device)

    for i in range(max_len - 1):
        # Create a target mask for autoregressive decoding
        tgt_mask = torch.tril(torch.full((ys.size(1), ys.size(1)), float('-inf'), device=device), diagonal=-1).transpose(0, 1)
        # Decode the target sequence
        out = model.decode(ys, memory, tgt_mask)
        # Generate the probability distribution over the vocabulary
        prob = model.generator(out[:, -1])

        # Select the next word with the highest probability
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        # Append the next word to the target sequence
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)

        # Check if the generated word is the end-of-sequence token
        if next_word == target_tokenizer.eos_token_id:
            break

    return ys


def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size, length_penalty):
    # Move inputs to the device
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Encode the source sequence
    memory = model.encode(src, src_mask)  # b * seqlen_src * hdim

    # Length-normalised score used to rank beams (GNMT-style length penalty)
    def normalized(score, length):
        return score / (((5 + length) / 6) ** length_penalty)

    # Each beam is (sequence, cumulative log-probability)
    beams = [(torch.tensor([[start_symbol]]).type(torch.long).to(device), 0.0)]
    complete_beams = []

    for i in range(max_len - 1):
        new_beams = []

        for ys, score in beams:

            # Create a target mask for autoregressive decoding
            tgt_mask = torch.tril(torch.full((ys.size(1), ys.size(1)), float('-inf'), device=device), diagonal=-1).transpose(0, 1)
            # Decode the target sequence
            out = model.decode(ys, memory, tgt_mask)  # b * seqlen_tgt * hdim
            # Log-probability distribution over the vocabulary
            log_prob = torch.log_softmax(model.generator(out[:, -1]), dim=-1)  # b * tgt_vocab_size

            # Get the top beam_size candidates for the next word
            _, top_indices = torch.topk(log_prob, beam_size, dim=1)  # b * beam_size

            for next_word in top_indices[0]:

                next_word = next_word.item()

                # Append the next word to the target sequence
                new_ys = torch.cat([ys, torch.full((1, 1), fill_value=next_word, dtype=src.dtype, device=device)], dim=1)
                new_score = score + log_prob[0][next_word].item()

                # Finished beams are set aside; the rest stay in the search
                if next_word == target_tokenizer.eos_token_id:
                    complete_beams.append((new_ys, new_score))
                else:
                    new_beams.append((new_ys, new_score))

        # Stop early once every beam has produced the end-of-sequence token
        if not new_beams:
            break

        # Sort the beams by length-normalised score and keep the top beam_size
        new_beams.sort(key=lambda x: normalized(x[1], x[0].size(1)), reverse=True)
        beams = new_beams[:beam_size]

    # Pick the best sequence among finished and unfinished beams
    beams = complete_beams + beams
    beams.sort(key=lambda x: normalized(x[1], x[0].size(1)), reverse=True)

    best_beam = beams[0][0]
    return best_beam

def translate(model: torch.nn.Module, src_sentence: str, strategy: str = 'greedy', length_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
    assert strategy in ['greedy', 'beam search'], 'the strategy for decoding has to be either greedy or beam search'
    # Tokenize the source sentence
    src = source_tokenizer(src_sentence, **token_config)['input_ids']
    num_tokens = src.shape[1]
    # Create a source mask
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    if strategy == 'greedy':
        # Generate the target tokens using greedy decoding
        tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + length_extend, start_symbol=target_tokenizer.bos_token_id).flatten()
    else:
        # Generate the target tokens using beam search decoding
        tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + length_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size, length_penalty=length_penalty).flatten()
    # Decode the target tokens and clean up the result
    return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)

special_tokens = {'unk_token':"[UNK]",
                  'cls_token':"[CLS]",
                  'eos_token': '[EOS]',
                  'sep_token':"[SEP]",
                  'pad_token':"[PAD]",
                  'mask_token':"[MASK]",
                  'bos_token':"[BOS]"}

source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')

addPreprocessing(source_tokenizer)
addPreprocessing(target_tokenizer)

token_config = {
    "add_special_tokens": True,
    "return_tensors": "pt",
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = load_model()
model.to(device)
model.eval()
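

# A minimal usage sketch, assuming 'model.pt' and 'config.yaml' sit next to this
# module; the input sentence is purely illustrative.
if __name__ == "__main__":
    example_sentence = "How are you today?"  # hypothetical source sentence
    print(translate(model, example_sentence))  # greedy decoding (default)
    print(translate(model, example_sentence, strategy='beam search', beam_size=5, length_penalty=0.6))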