mynkchaudhry committed on
Commit
b9fc60e
·
verified ·
1 Parent(s): a7c7acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -1,18 +1,23 @@
1
  from transformers import PegasusTokenizer, PegasusForConditionalGeneration
2
- import gradio as gr
3
 
4
- # Load model and tokenizer based on the provided configuration
5
  tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-cnn_dailymail")
6
  model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-cnn_dailymail")
7
 
8
- def summarize(text, prompt):
9
- # Prepend the prompt to the input text
10
- inputs = tokenizer(prompt + " " + text, return_tensors="pt", max_length=1024, truncation=True)
11
- # Generate the summary
12
- summary_ids = model.generate(inputs.input_ids, max_length=128, min_length=32, length_penalty=0.8, num_beams=8, early_stopping=True)
 
 
 
 
 
13
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
14
  return summary
15
 
 
16
  # Gradio interface
17
  interface = gr.Interface(
18
  fn=summarize,
 
1
import gradio as gr
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
 
2
 
3
+ # Load tokenizer and model
4
  tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-cnn_dailymail")
5
  model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-cnn_dailymail")
6
 
7
def summarize(text):
    """Summarize *text* with the Pegasus CNN/DailyMail model.

    Parameters
    ----------
    text : str
        The article or passage to summarize. Inputs longer than the
        model's 1024-token limit are truncated.

    Returns
    -------
    str
        The beam-search-generated summary with special tokens removed.
    """
    # Truncate so the encoder never sees more than its 1024-token maximum.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    # Pass the attention mask explicitly so any padded positions are masked
    # out during generation instead of being attended to (transformers emits
    # a warning when generate() is called with input_ids alone).
    summary_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=1500,
        min_length=100,
        length_penalty=0.9,
        num_beams=6,
        early_stopping=True,
    )
    # Decode the single best beam back into plain text.
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
19
 
20
+
21
  # Gradio interface
22
  interface = gr.Interface(
23
  fn=summarize,