Spaces:
Sleeping
Sleeping
File size: 2,638 Bytes
007521f 56da2e5 007521f 56da2e5 007521f 56da2e5 007521f 56da2e5 007521f 56da2e5 007521f 56da2e5 007521f 56da2e5 007521f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from tqdm.auto import tqdm
def handle_long_text(
input_text: str,
model: AutoModelForSeq2SeqLM,
tokenizer: AutoTokenizer,
max_length: int = 128,
stride: int = 128,
batch_length: int = 2048,
min_batch_length: int = 512,
**generate_kwargs,
) -> str:
"""
Maneja textos largos dividiéndolos en segmentos y generando resúmenes para cada uno.
Args:
input_text (str): Texto completo a resumir.
model: Modelo de resumen abstractivo.
tokenizer: Tokenizador asociado al modelo.
max_length (int): Longitud máxima del resumen generado por segmento.
stride (int): Cantidad de tokens que se superponen entre segmentos.
batch_length (int): Longitud máxima de tokens por segmento.
min_batch_length (int): Longitud mínima permitida por segmento.
generate_kwargs: Parámetros adicionales para el modelo de generación.
Returns:
str: Resumen final concatenado de todos los segmentos.
"""
# Validar parámetros de longitud
if batch_length < min_batch_length:
batch_length = min_batch_length
# Tokenizar texto completo en segmentos
encoded_input = tokenizer(
input_text,
return_tensors="pt",
max_length=batch_length,
truncation=True,
stride=stride,
return_overflowing_tokens=True,
add_special_tokens=True,
)
# Obtener IDs y máscaras de atención
input_ids = encoded_input["input_ids"]
attention_masks = encoded_input["attention_mask"]
# Progresión para múltiples segmentos
summaries = []
pbar = tqdm(total=len(input_ids), desc="Procesando segmentos")
for ids, mask in zip(input_ids, attention_masks):
# Enviar al dispositivo correcto (CPU/GPU)
ids = ids.unsqueeze(0).to(model.device)
mask = mask.unsqueeze(0).to(model.device)
# Generar resumen para el segmento actual
outputs = model.generate(
input_ids=ids,
attention_mask=mask,
max_length=max_length,
no_repeat_ngram_size=3,
num_beams=4,
early_stopping=True,
**generate_kwargs,
)
# Decodificar resumen generado
summary = tokenizer.decode(
outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
)
summaries.append(summary)
pbar.update()
pbar.close()
# Concatenar resúmenes y devolver el texto final
final_summary = " ".join(summaries)
return final_summary
|