mrm8488 commited on
Commit
9c9a793
·
1 Parent(s): 495c4a1
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -14,8 +14,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
 
15
 
16
  def summarize(lang, text):
17
- tokenizer = RobertaTokenizerFast.from_pretrained(models_paths[lang]) if lang in [
18
- "fr", "es"] else BertTokenizerFast.from_pretrained(models_paths[lang])
19
  model = EncoderDecoderModel.from_pretrained(models_paths[lang]).to(device)
20
  inputs = tokenizer([text], padding="max_length",
21
  truncation=True, max_length=512, return_tensors="pt")
@@ -25,5 +24,5 @@ def summarize(lang, text):
25
  return tokenizer.decode(output[0], skip_special_tokens=True)
26
 
27
 
28
- gr.Interface(fn=summarize, inputs=[gr.inputs.CheckboxGroup(["fr", "de", "tu", "es"]), gr.inputs.Textbox(
29
  lines=7, label="Input Text")], outputs="text").launch(inline=False)
 
14
 
15
 
16
  def summarize(lang, text):
17
+ tokenizer = RobertaTokenizerFast.from_pretrained(models_paths[lang]) if lang == "fr" or lang == "es" else BertTokenizerFast.from_pretrained(models_paths[lang])
 
18
  model = EncoderDecoderModel.from_pretrained(models_paths[lang]).to(device)
19
  inputs = tokenizer([text], padding="max_length",
20
  truncation=True, max_length=512, return_tensors="pt")
 
24
  return tokenizer.decode(output[0], skip_special_tokens=True)
25
 
26
 
27
+ gr.Interface(fn=summarize, inputs=[gr.inputs.Radio(["fr", "de", "tu", "es"]), gr.inputs.Textbox(
28
  lines=7, label="Input Text")], outputs="text").launch(inline=False)