Update app.py
Browse files
app.py
CHANGED
@@ -207,7 +207,7 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
|
|
207 |
best_beam = beams[0][0]
|
208 |
return best_beam
|
209 |
|
210 |
-
def translate(model: torch.nn.Module, strategy:str = 'greedy' ,
|
211 |
assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
|
212 |
# Tokenize the source sentence
|
213 |
src = source_tokenizer(src_sentence, **token_config)['input_ids']
|
@@ -249,7 +249,19 @@ model.eval()
|
|
249 |
|
250 |
import gradio as gr
|
251 |
|
252 |
-
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
iface.launch()
|
|
|
207 |
best_beam = beams[0][0]
|
208 |
return best_beam
|
209 |
|
210 |
+
def translate(model: torch.nn.Module, src_sentence: str, strategy:str = 'greedy' , lenght_extend :int = 5, beam_size: int = 5, length_penalty:float = 0.6):
|
211 |
assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
|
212 |
# Tokenize the source sentence
|
213 |
src = source_tokenizer(src_sentence, **token_config)['input_ids']
|
|
|
249 |
|
250 |
import gradio as gr
|
251 |
|
252 |
+
iface = gr.Interface(
|
253 |
+
fn=translate,
|
254 |
+
inputs=[
|
255 |
+
gr.inputs.Textbox("Enter a sentence to translate"),
|
256 |
+
gr.inputs.Radio(['greedy', 'beam search'], label="Decoding Strategy"),
|
257 |
+
gr.inputs.Number(label="Length Extend (for greedy)", default=5),
|
258 |
+
gr.inputs.Number(label="Beam Size (for beam search)", default=5),
|
259 |
+
gr.inputs.Number(label="Length Penalty (for beam search)", default=0.6)
|
260 |
+
],
|
261 |
+
outputs=gr.outputs.Textbox("Translation"),
|
262 |
+
title="Translation Interface",
|
263 |
+
description="Translate text using a pre-trained model.",
|
264 |
+
)
|
265 |
+
|
266 |
+
# Launch the Gradio interface
|
267 |
iface.launch()
|