ToS-Summarization / abstractive_model.py
EE21's picture
Update abstractive_model.py
d4396fe
raw
history blame
975 Bytes
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Checkpoint identifier (Hugging Face Hub repo id) of the fine-tuned BART
# model for Terms-of-Service simplification.
_MODEL_ID = "EE21/BART-ToSSimplify"

# Both objects are loaded once, at import time, from the same checkpoint so
# the tokenizer's vocabulary matches the model. They are module-level globals
# read by summarize_with_bart.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)
def summarize_with_bart(
    input_text,
    max_summary_tokens=200,
    min_summary_tokens=100,
    do_sample=False,
):
    """Summarize ``input_text`` with the module-level BART ``model`` and ``tokenizer``.

    Args:
        input_text: The text to summarize.
        max_summary_tokens: Passed to ``model.generate`` as ``max_length``.
            This is the upper bound, in tokens, on the generated sequence.
        min_summary_tokens: Passed to ``model.generate`` as ``min_length``.
            This is the lower bound, in tokens, on the generated sequence.
        do_sample: If True, sample during generation. Otherwise decode
            deterministically (greedy or beam search, depending on the
            checkpoint's generation config).

    Returns:
        The decoded summary string, with special tokens removed.
    """
    # Truncate to the tokenizer's model_max_length. BART's learned position
    # embeddings cover a fixed number of positions (1024 for standard BART
    # checkpoints; confirm for this one). Without truncation, a longer input,
    # which is common for full ToS documents, fails with an index error inside
    # the encoder. Only the leading portion of an over-long text is summarized.
    encoded = tokenizer(input_text, return_tensors="pt", truncation=True)
    # Keep the inputs on the same device as the model, in case it was moved
    # (for example, to a GPU).
    encoded = encoded.to(model.device)

    # Pass the attention mask explicitly rather than letting generate infer it.
    outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        max_length=max_summary_tokens,
        min_length=min_summary_tokens,
        do_sample=do_sample,
    )

    # Batch size is 1, so the summary is the first generated sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)