# app.py -- revision 3fb17dc ("Update app.py"), author: luminoria, 2.44 kB
from transformers import T5ForConditionalGeneration,T5Tokenizer
from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import pipeline
import streamlit as st
# Hugging Face Hub checkpoint identifiers for the pretrained models loaded below.
_HEADLINE_CHECKPOINT = "Michau/t5-base-en-generate-headline"
_NEWS_SUMMARY_CHECKPOINT = "mrm8488/t5-base-finetuned-summarize-news"

# Headline model/tokenizer pair used by generate_title().
model = T5ForConditionalGeneration.from_pretrained(_HEADLINE_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(_HEADLINE_CHECKPOINT)

# News-summarization tokenizer/model pair. In the visible code these are only
# referenced by the commented-out generate_summary variant.
mrm_tokenizer = AutoTokenizer.from_pretrained(_NEWS_SUMMARY_CHECKPOINT)
mrm_model = AutoModelWithLMHead.from_pretrained(_NEWS_SUMMARY_CHECKPOINT)
def generate_title(article):
    """Generate a headline for *article* with the module-level headline model.

    Args:
        article: Raw article text. It is prefixed with ``"headline: "`` before
            being tokenized.

    Returns:
        The decoded text of the first sequence returned by ``model.generate``,
        with tokenizer special tokens (e.g. ``<pad>``, ``</s>``) removed.
    """
    text = "headline: " + article
    # Token sequences longer than max_length are truncated by the tokenizer.
    encoding = tokenizer(
        text,
        return_tensors="pt",
        max_length=2048,
        truncation=True,
    )
    # Sampling is enabled (do_sample=True), so repeated calls on the same
    # article may produce different headlines.
    beam_outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=50,
        num_beams=3,
        do_sample=True,
        top_k=10,
        early_stopping=False,
    )
    # Drop special tokens at decode time so callers never see markers such as
    # <pad>, </s> or <unk> and do not have to strip them by hand.
    return tokenizer.decode(beam_outputs[0], skip_special_tokens=True)
# def generate_summary(article):
# input_ids = mrm_tokenizer.encode(article, return_tensors="pt", add_special_tokens=True)
# generated_ids = mrm_model.generate(input_ids=input_ids, num_beams=3, max_length=200, repetition_penalty=2.5, length_penalty=1.0, early_stopping=False, truncation=True)
# preds = [mrm_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
# return preds[0]
def generate_summary(article):
    """Summarize *article* with the ``facebook/bart-large-cnn`` pipeline.

    Args:
        article: Text to summarize. Only its first 1024 characters are used.

    Returns:
        The raw output of the summarization pipeline. Callers read the text
        via ``result[0]['summary_text']``.
    """
    article = article[:1024]
    # NOTE(review): the pipeline (and therefore the model) is constructed on
    # every call; consider caching it, e.g. with Streamlit's resource cache.
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    # The 1024-character slice above does not bound the *token* count (one
    # character can map to several tokens), so also ask the pipeline to
    # truncate to the model's maximum input length instead of failing on
    # over-long inputs.
    return summarizer(
        article,
        max_length=130,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
def main():
    """Render the Streamlit page and handle the "Generate Summary" button.

    Shows a text area; when the button is pressed, either reports an error for
    blank input or displays a generated title and description for the text.
    """
    st.title("Text Summarization")
    article = st.text_area("Enter your text here:", "")

    # Nothing to do until the user presses the button.
    if not st.button("Generate Summary"):
        return

    # Reject input that is empty or whitespace-only.
    if not article.strip():
        st.error("Please enter some text.")
        return

    # Run both generators before rendering any output.
    headline = generate_title(article)
    summary_result = generate_summary(article)

    st.subheader("Generated Title:")
    # Strip the special-token markers the title may contain.
    for marker in ("<pad>", "</s>"):
        headline = headline.replace(marker, "")
    st.write(headline)

    st.subheader("Generated Description:")
    st.write(summary_result[0]["summary_text"])


if __name__ == "__main__":
    main()