File size: 711 Bytes
56da2e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def handle_long_text(text, model, tokenizer, max_length=2048, stride=128):
    """Summarize text longer than the model's context window, chunk by chunk.

    The text is tokenized into overlapping windows of ``max_length`` tokens
    (``stride`` tokens of overlap between consecutive windows), each window is
    summarized independently with beam search, and the partial summaries are
    joined into one string.

    Args:
        text: Input string to summarize.
        model: A seq2seq model exposing ``generate`` and ``device``
            (Hugging Face transformers style — assumed, confirm with caller).
        tokenizer: The tokenizer paired with ``model``.
        max_length: Maximum number of tokens per chunk.
        stride: Token overlap between consecutive chunks, so content cut at a
            chunk boundary still appears whole in the next chunk.

    Returns:
        The chunk summaries concatenated with single spaces.
    """
    encoded_input = tokenizer(
        text,
        max_length=max_length,
        stride=stride,
        truncation=True,
        return_overflowing_tokens=True,
        # padding is required here: overflow chunks have unequal lengths
        # (the last chunk is usually shorter), and return_tensors="pt"
        # cannot stack ragged sequences into one tensor without it.
        padding=True,
        return_tensors="pt",
    )
    summaries = []
    for input_ids, attention_mask in zip(
        encoded_input.input_ids, encoded_input.attention_mask
    ):
        # zip yields unbatched (seq_len,) rows, but generate() expects a
        # batched (1, seq_len) tensor — add the batch dimension explicitly.
        output = model.generate(
            input_ids.unsqueeze(0).to(model.device),
            attention_mask=attention_mask.unsqueeze(0).to(model.device),
            max_length=128,
            num_beams=4,
        )
        summaries.append(tokenizer.decode(output[0], skip_special_tokens=True))
    return " ".join(summaries)