Sanj004 commited on
Commit
7ae75ec
·
verified ·
1 Parent(s): b2bb5a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -11,22 +11,22 @@ model = AutoModelForSeq2SeqLM.from_pretrained(SAVED_MODEL_PATH).to(device)
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
13
  dataset = load_dataset("samsum")
 
 
14
 
15
  train_data = dataset["train"]
16
  validation_data = dataset["validation"]
17
  test_data = dataset["test"]
18
 
19
- def summarize(tokenizer, model, text):
20
  inputs = tokenizer(f"Summarize dialogue >>\n {text}", return_tensors="pt", max_length=1000, truncation=True, padding="max_length").to(device)
21
  summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=100, early_stopping=True)
22
  summary = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
23
  return summary[0]
24
 
25
- def summarize_dialogue(text):
26
- return summarize(tokenizer, model, text)
27
 
28
  iface = gr.Interface(
29
- fn=summarize_dialogue,
30
  inputs=gr.inputs.Textbox(lines=10, label="Input Dialogue"),
31
  outputs=gr.outputs.Textbox(label="Generated Summary")
32
  )
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
13
  dataset = load_dataset("samsum")
14
+ #dataset = load_dataset("samsum", download_mode="force_redownload")
15
+
16
 
17
  train_data = dataset["train"]
18
  validation_data = dataset["validation"]
19
  test_data = dataset["test"]
20
 
21
+ def summarize(text):
22
  inputs = tokenizer(f"Summarize dialogue >>\n {text}", return_tensors="pt", max_length=1000, truncation=True, padding="max_length").to(device)
23
  summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=100, early_stopping=True)
24
  summary = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
25
  return summary[0]
26
 
 
 
27
 
28
  iface = gr.Interface(
29
+ fn=summarize,
30
  inputs=gr.inputs.Textbox(lines=10, label="Input Dialogue"),
31
  outputs=gr.outputs.Textbox(label="Generated Summary")
32
  )