Overglitch's picture
first commit
56da2e5
raw
history blame
711 Bytes
def handle_long_text(
    text,
    model,
    tokenizer,
    max_length=2048,
    stride=128,
    summary_max_length=128,
    num_beams=4,
):
    """Summarize text longer than the model's context window.

    The text is tokenized into overlapping chunks of at most ``max_length``
    tokens (consecutive chunks share ``stride`` tokens of overlap via
    ``return_overflowing_tokens``), each chunk is summarized independently
    with beam search, and the chunk summaries are joined with spaces.

    Args:
        text: Input string to summarize.
        model: A seq2seq model exposing ``generate`` and ``device``
            (e.g. a Hugging Face ``*ForConditionalGeneration``).
        tokenizer: Matching tokenizer; must support
            ``return_overflowing_tokens`` and ``decode``.
        max_length: Maximum tokens per input chunk.
        stride: Token overlap between consecutive chunks, so sentences cut
            at a chunk boundary still appear whole in the next chunk.
        summary_max_length: ``max_length`` passed to ``generate`` for each
            chunk's summary (previously hard-coded to 128).
        num_beams: Beam-search width for ``generate`` (previously
            hard-coded to 4).

    Returns:
        A single string: the per-chunk summaries joined by spaces.
    """
    encoded_input = tokenizer(
        text,
        max_length=max_length,
        stride=stride,
        truncation=True,
        return_overflowing_tokens=True,
        return_tensors="pt",
    )
    summaries = []
    for input_ids, attention_mask in zip(
        encoded_input.input_ids, encoded_input.attention_mask
    ):
        # Each chunk comes out of the zip as a 1-D (seq_len,) tensor;
        # generate() requires a batched 2-D (1, seq_len) tensor, so add
        # the batch dimension (the original code omitted this and passed
        # rank-1 tensors straight through).
        output = model.generate(
            input_ids.unsqueeze(0).to(model.device),
            attention_mask=attention_mask.unsqueeze(0).to(model.device),
            max_length=summary_max_length,
            num_beams=num_beams,
        )
        # output is (1, out_len); decode the single sequence in the batch.
        summaries.append(tokenizer.decode(output[0], skip_special_tokens=True))
    return " ".join(summaries)