Sifal commited on
Commit
2ce4922
1 Parent(s): 9065f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -207,7 +207,7 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
207
  best_beam = beams[0][0]
208
  return best_beam
209
 
210
- def translate(model: torch.nn.Module, strategy:str = 'greedy' , src_sentence: str, lenght_extend :int = 5, beam_size: int = 5, length_penalty:float = 0.6):
211
  assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
212
  # Tokenize the source sentence
213
  src = source_tokenizer(src_sentence, **token_config)['input_ids']
@@ -249,7 +249,19 @@ model.eval()
249
 
250
  import gradio as gr
251
 
252
- x = lambda text : translate(x)
253
-
254
- iface = gr.Interface(fn=x, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
255
  iface.launch()
 
207
  best_beam = beams[0][0]
208
  return best_beam
209
 
210
+ def translate(model: torch.nn.Module, src_sentence: str, strategy:str = 'greedy' , lenght_extend :int = 5, beam_size: int = 5, length_penalty:float = 0.6):
211
  assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
212
  # Tokenize the source sentence
213
  src = source_tokenizer(src_sentence, **token_config)['input_ids']
 
249
 
250
  import gradio as gr
251
 
252
+ iface = gr.Interface(
253
+ fn=translate,
254
+ inputs=[
255
+ gr.inputs.Textbox("Enter a sentence to translate"),
256
+ gr.inputs.Radio(['greedy', 'beam search'], label="Decoding Strategy"),
257
+ gr.inputs.Number(label="Length Extend (for greedy)", default=5),
258
+ gr.inputs.Number(label="Beam Size (for beam search)", default=5),
259
+ gr.inputs.Number(label="Length Penalty (for beam search)", default=0.6)
260
+ ],
261
+ outputs=gr.outputs.Textbox("Translation"),
262
+ title="Translation Interface",
263
+ description="Translate text using a pre-trained model.",
264
+ )
265
+
266
+ # Launch the Gradio interface
267
  iface.launch()