content_summarizer / summarize.py
KevlarVK's picture
simple code to summarize using bart-large-cnn
a4f4f24
raw
history blame
1.8 kB
from datetime import datetime
from transformers import BartTokenizer, TFBartForConditionalGeneration, pipeline
from Utils import fetch_article_text, count_tokens
import re
from nltk.tokenize import sent_tokenize
# Hugging Face checkpoint shared by the tokenizer and the model below.
_CHECKPOINT = 'facebook/bart-large-cnn'

# Both objects are created once, when this module is imported, and are then
# reused by every bart_summarize call.
tokenizer = BartTokenizer.from_pretrained(_CHECKPOINT)
model = TFBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
def _pack_sentences(sentences, budget: int) -> list:
    """Greedily pack consecutive sentences into space-joined chunks.

    A sentence is added to the current chunk while the chunk's running
    ``count_tokens`` total stays strictly below ``budget``. Otherwise the
    chunk is closed and a new one starts with that sentence. A sentence that
    alone reaches ``budget`` becomes its own (oversized) chunk; no empty
    chunk is ever produced.

    NOTE(review): assumes ``count_tokens`` (from Utils) approximates the
    BART token count of a string — TODO confirm.
    """
    chunks = []
    current = []
    current_tokens = 0
    for sentence in sentences:
        # Count each sentence once instead of twice.
        sentence_tokens = count_tokens(sentence)
        # Only close a chunk that actually holds sentences. The old code
        # emitted an empty "" chunk when the first sentence of a chunk was
        # already over budget.
        if current and current_tokens + sentence_tokens >= budget:
            chunks.append(" ".join(current))
            current, current_tokens = [], 0
        current.append(sentence)
        current_tokens += sentence_tokens
    if current:
        chunks.append(" ".join(current))
    return chunks


def bart_summarize(text: str, *, summary_max_length: int = 300, num_beams: int = 4) -> str:
    """Summarize arbitrarily long text with BART, one input-sized chunk at a time.

    The text is split into sentences with ``sent_tokenize``. Empty sentences
    and sentences of four or fewer space-separated words are dropped. The rest
    are packed greedily into chunks that fit the model's input window. Each
    chunk is summarized independently, and the chunk summaries are joined with
    single spaces.

    Args:
        text: The text to summarize.
        summary_max_length: ``max_length`` passed to ``model.generate`` for
            each chunk (default 300, the previously hard-coded value).
        num_beams: Beam-search width passed to ``model.generate``
            (default 4, the previously hard-coded value).

    Returns:
        The space-joined chunk summaries, or ``""`` when no sentence survives
        the filter.
    """
    max_length = model.config.max_position_embeddings
    sentences = [
        sentence
        for sentence in sent_tokenize(text)
        if sentence.strip() and len(sentence.split(" ")) > 4
    ]

    # Leave room for the special tokens (e.g. <s> ... </s>) that the tokenizer
    # adds around every chunk. Otherwise a chunk filling the whole window
    # loses its tail to truncation.
    budget = max_length - tokenizer.num_special_tokens_to_add()
    # Sentences are re-joined with a space: sent_tokenize drops the
    # whitespace between them, and bare concatenation glued them together.
    chunks = _pack_sentences(sentences, budget)

    summaries = []
    for chunk in chunks:
        # truncation=True is a backstop for a single sentence longer than the
        # window, which still forms its own chunk and is cut here.
        encoded_input = tokenizer.encode(chunk, max_length=max_length, truncation=True, padding='longest', return_tensors='tf')
        summary_ids = model.generate(encoded_input, max_length=summary_max_length, num_beams=num_beams, early_stopping=True)
        summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

    return " ".join(summaries)