luminoria's picture
Update app.py
bc70267 verified
raw
history blame
2.95 kB
import re

import streamlit as st
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import pipeline
# Headline model used by generate_title(), which prompts it with a
# "headline: " task prefix.
model = T5ForConditionalGeneration.from_pretrained("Michau/t5-base-en-generate-headline")
tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")

# T5 news-summarization checkpoint. AutoModelForSeq2SeqLM replaces the
# deprecated AutoModelWithLMHead; for an encoder-decoder (T5) checkpoint both
# resolve to the same model class, without the deprecation warning.
# NOTE(review): mrm_tokenizer / mrm_model appear unused in this file; loading
# them costs memory and startup time — confirm they are needed.
mrm_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
mrm_model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")

# Smaller T5 headline-generation checkpoint.
jules_tokenizer = AutoTokenizer.from_pretrained("JulesBelveze/t5-small-headline-generator")
jules_model = T5ForConditionalGeneration.from_pretrained("JulesBelveze/t5-small-headline-generator")
def WHITESPACE_HANDLER(k):
    """Collapse every run of whitespace in *k* to a single space.

    Leading and trailing whitespace is removed first. Newlines are whitespace
    too, so multi-line input becomes one line.

    >>> WHITESPACE_HANDLER("  a\\n\\n b\\t c  ")
    'a b c'
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). A single \s+ pass also covers the
    # newline runs the old nested '\n+' substitution handled.
    return re.sub(r"\s+", " ", k.strip())
def generate_title(article):
    """Generate a headline for *article* with the module-level T5 headline model.

    The article is prefixed with ``"headline: "`` and truncated at 2048
    tokens. It is then decoded with deterministic 3-beam search, producing at
    most 50 tokens.

    Args:
        article: Raw article text.

    Returns:
        The generated headline as plain text. Special tokens (e.g. ``<pad>``,
        ``</s>``) are removed and surrounding whitespace is stripped.
    """
    encoding = tokenizer(
        "headline: " + article,
        return_tensors="pt",
        max_length=2048,
        truncation=True,
    )
    beam_outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=50,
        num_beams=3,
        do_sample=False,
        early_stopping=False,
    )
    # skip_special_tokens keeps <pad>/</s> out of the returned text instead of
    # leaving callers to strip them by string replacement.
    return tokenizer.decode(beam_outputs[0], skip_special_tokens=True).strip()
def generate_title_2(article):
    """Generate an alternative headline for *article* with the Jules T5-small model.

    Whitespace in *article* is collapsed with WHITESPACE_HANDLER. The text is
    then tokenized, padded or truncated to 384 tokens, and decoded with 4-beam
    search (no repeated bigrams, at most 84 tokens).

    Args:
        article: Raw article text.

    Returns:
        The generated headline with special tokens removed.
    """
    # Uses jules_tokenizer / jules_model: this preprocessing and these
    # generation settings follow the JulesBelveze/t5-small-headline-generator
    # usage example. The Michau `model` would need a "headline: " prefix,
    # which this path does not add.
    encoded = jules_tokenizer(
        [WHITESPACE_HANDLER(article)],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=384,
    )
    output_ids = jules_model.generate(
        input_ids=encoded["input_ids"],
        # The input is padded to 384 tokens. Pass the mask so the encoder
        # ignores the padding positions.
        attention_mask=encoded["attention_mask"],
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4,
    )[0]
    return jules_tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
@st.cache_resource
def _get_summarizer():
    """Build the BART-CNN summarization pipeline once and share it across calls and reruns."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


def generate_summary(article):
    """Summarize *article* with facebook/bart-large-cnn.

    Input longer than the model's maximum length is truncated by the tokenizer
    (in tokens, not characters). Decoding is deterministic and yields a summary
    of 30-130 tokens.

    Args:
        article: Raw article text.

    Returns:
        The summarization pipeline's output for a single input: a sequence
        whose first element is a dict with a ``"summary_text"`` key.
    """
    # The pipeline (about 1.6 GB of weights) used to be rebuilt on every call;
    # it is now cached by _get_summarizer(). truncation=True replaces the old
    # article[:1024] slice, which cut at 1024 *characters* even though the
    # model's limit is counted in tokens.
    summarizer = _get_summarizer()
    return summarizer(
        article,
        max_length=130,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
def main():
    """Render the Streamlit page.

    The page shows a text area and a button. When the button is pressed with
    non-blank text, it displays two generated titles and a summary.
    """
    st.title("Text Summarization")
    user_text = st.text_area("Enter your text here:", "")

    # Nothing more to render until the button is pressed.
    if not st.button("Generate Summary"):
        return

    if not user_text.strip():
        st.error("Please enter some text.")
        return

    # Keep the generation order: each call runs a model.
    headline = generate_title(user_text)
    alt_headline = generate_title_2(user_text)
    description = generate_summary(user_text)

    st.subheader("Generated Title:")
    # Drop any raw T5 special tokens that may remain in the decoded headline.
    st.write(headline.replace('<pad>', '').replace('</s>', ''))

    st.subheader("Second Title:")
    st.write(alt_headline)

    st.subheader("Generated Description:")
    # Show the first summary's text.
    st.write(description[0]['summary_text'])


if __name__ == "__main__":
    main()