content_summarizer / title_generator.py
KevlarVK's picture
Added support for title generation
9a4b6ed
raw
history blame
787 Bytes
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
class T5Summarizer:
    """Generate a short title/summary for a text with a seq2seq checkpoint.

    The tokenizer and the TensorFlow seq2seq model are both loaded from the
    same ``model_name`` via the ``transformers`` auto classes.
    """

    # Task prefix prepended to every input before tokenization.
    _TASK_PREFIX = "summarize: "

    def __init__(self, model_name: str = "fabiochiu/t5-small-medium-title-generation"):
        """Load the tokenizer and model identified by ``model_name``.

        Args:
            model_name: Identifier passed unchanged to both
                ``AutoTokenizer.from_pretrained`` and
                ``TFAutoModelForSeq2SeqLM.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

    def summarize(
        self,
        text: str,
        *,
        max_length: int = 10,
        min_length: int = 1,
        num_beams: int = 8,
        do_sample: bool = True,
    ) -> str:
        """Return the text generated by the model for ``text``.

        The input is prefixed with ``"summarize: "``, tokenized with
        truncation at the tokenizer's ``model_max_length``, and passed to
        ``model.generate``. The first decoded sequence is returned.

        Args:
            text: The text to summarize.
            max_length: Forwarded to ``generate`` as ``max_length``.
            min_length: Forwarded to ``generate`` as ``min_length``.
            num_beams: Forwarded to ``generate`` as ``num_beams``.
            do_sample: Forwarded to ``generate`` as ``do_sample``. Pass
                ``False`` to turn sampling off.

        Returns:
            The decoded output with special tokens stripped, or ``""`` when
            ``text`` is empty or contains only whitespace.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"text must be a str, not {type(text).__name__}")
        # With no content the model would only see the task prefix and would
        # produce a title unrelated to anything; skip the model call entirely.
        if not text.strip():
            return ""

        encoded = self.tokenizer(
            [self._TASK_PREFIX + text],
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="tf",
        )
        generated = self.model.generate(
            **encoded,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            max_length=max_length,
            early_stopping=True,
        )
        # A single input was encoded, so the batch holds exactly one result.
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]