Update utlis.py
Browse files
utlis.py
CHANGED
@@ -127,17 +127,15 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
|
|
127 |
best_beam = beams[0][0]
|
128 |
return best_beam
|
129 |
|
130 |
-
def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_extend :int =
|
131 |
assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
|
132 |
-
model.to(device)
|
133 |
-
model.eval()
|
134 |
# Tokenize the source sentence
|
135 |
src = source_tokenizer(src_sentence, **token_config)['input_ids']
|
136 |
num_tokens = src.shape[1]
|
137 |
# Create a source mask
|
138 |
src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
|
139 |
if strategy == 'greedy':
|
140 |
-
tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens +
|
141 |
# Generate the target tokens using beam search decoding
|
142 |
else:
|
143 |
tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size,length_penalty=length_penalty).flatten()
|
@@ -163,4 +161,8 @@ token_config = {
|
|
163 |
"return_tensors": True,
|
164 |
}
|
165 |
|
166 |
-
|
|
|
|
|
|
|
|
|
|
127 |
best_beam = beams[0][0]
|
128 |
return best_beam
|
129 |
|
130 |
def translate(model: torch.nn.Module, strategy: str = 'greedy', src_sentence: str = '', lenght_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
    """Translate one source sentence with the given seq2seq model.

    NOTE(review): the committed signature was
    ``strategy: str = 'greedy', src_sentence: str`` — a non-default parameter
    after a default one, which is a SyntaxError and makes the whole module
    unimportable. ``src_sentence`` is given a sentinel default here (validated
    below) so positional callers of both the old and new signatures keep working.

    Args:
        model: the trained transformer used for decoding.
        strategy: decoding strategy, either ``'greedy'`` or ``'beam search'``.
        src_sentence: sentence to translate; required in practice (the default
            exists only to keep the parameter order valid and backward-compatible).
        lenght_extend: extra tokens allowed beyond the source length
            (name kept as-is, misspelling included, to preserve keyword callers).
        beam_size: beam width, used only when ``strategy == 'beam search'``.
        length_penalty: length-penalty exponent for beam-search scoring.
    """
    assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
    if not src_sentence:
        raise ValueError('src_sentence must be a non-empty string')
    # Tokenize the source sentence
    src = source_tokenizer(src_sentence, **token_config)['input_ids']
    num_tokens = src.shape[1]
    # Create a source mask
    # NOTE(review): an all-False (i.e. no-op) mask — presumably the encoder
    # attends to every source position; confirm against the model's mask convention.
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    if strategy == 'greedy':
        tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id).flatten()
    # Generate the target tokens using beam search decoding
    else:
        tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size, length_penalty=length_penalty).flatten()
|
|
|
161 |
"return_tensors": True,
|
162 |
}
|
163 |
|
164 |
# --- Module-level inference setup ---------------------------------------
# Prefer the first CUDA device when one is present; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the trained model once at import time, move it to the chosen device,
# and switch it to evaluation mode (disables dropout / batch-norm updates).
model = load_model()
model.to(device)
model.eval()