luminoria commited on
Commit
346eea8
1 Parent(s): 8cc2067

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import T5ForConditionalGeneration,T5Tokenizer
2
+ from transformers import AutoModelWithLMHead, AutoTokenizer
3
+ from transformers import pipeline
4
+ import streamlit as st
5
+
6
+ model = T5ForConditionalGeneration.from_pretrained("Michau/t5-base-en-generate-headline")
7
+ tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
8
+
9
+ mrm_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
10
+ mrm_model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
11
+
12
+
13
+ def generate_title(article):
14
+ text = "headline: " + article
15
+ encoding = tokenizer.encode_plus(text, return_tensors = "pt", max_length=2048, truncation=True)
16
+ input_ids = encoding["input_ids"]
17
+ attention_masks = encoding["attention_mask"]
18
+
19
+ beam_outputs = model.generate(
20
+ input_ids = input_ids,
21
+ attention_mask = attention_masks,
22
+ max_length = 50,
23
+ num_beams = 3,
24
+ do_sample = True,
25
+ top_k=10,
26
+ early_stopping = False,
27
+ )
28
+
29
+ return tokenizer.decode(beam_outputs[0])
30
+
31
+ # def generate_summary(article):
32
+ # input_ids = mrm_tokenizer.encode(article, return_tensors="pt", add_special_tokens=True)
33
+
34
+ # generated_ids = mrm_model.generate(input_ids=input_ids, num_beams=3, max_length=200, repetition_penalty=2.5, length_penalty=1.0, early_stopping=False, truncation=True)
35
+
36
+ # preds = [mrm_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
37
+
38
+ # return preds[0]
39
+ def generate_summary(article):
40
+ article = article[:1024]
41
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
42
+ return summarizer(article, max_length=130, min_length=30, do_sample=False)
43
+ def main():
44
+ st.title("Text Summarization")
45
+ text = st.text_area("Enter your text here:", "")
46
+
47
+ if st.button("Generate Summary"):
48
+ if text.strip() == "":
49
+ st.error("Please enter some text.")
50
+ else:
51
+ title = generate_title(text)
52
+ summary = generate_summary(text)
53
+ # summary = summary[0]['summary_text']
54
+
55
+ st.subheader("Generated Title:")
56
+ st.write(title.replace('<pad>', '').replace('</s>', ''))
57
+
58
+ st.subheader("Generated Description:")
59
+
60
+ # st.write(summary.replace('<pad>', '').replace('</s>', ''))
61
+ st.write(summary[0]['summary_text'])
62
+
63
+
64
+ if __name__ == "__main__":
65
+ main()