Spaces:
Sleeping
Sleeping
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import torch | |
from tqdm.auto import tqdm | |
def handle_long_text( | |
input_text: str, | |
model: AutoModelForSeq2SeqLM, | |
tokenizer: AutoTokenizer, | |
max_length: int = 128, | |
stride: int = 128, | |
batch_length: int = 2048, | |
min_batch_length: int = 512, | |
**generate_kwargs, | |
) -> str: | |
""" | |
Maneja textos largos dividiéndolos en segmentos y generando resúmenes para cada uno. | |
Args: | |
input_text (str): Texto completo a resumir. | |
model: Modelo de resumen abstractivo. | |
tokenizer: Tokenizador asociado al modelo. | |
max_length (int): Longitud máxima del resumen generado por segmento. | |
stride (int): Cantidad de tokens que se superponen entre segmentos. | |
batch_length (int): Longitud máxima de tokens por segmento. | |
min_batch_length (int): Longitud mínima permitida por segmento. | |
generate_kwargs: Parámetros adicionales para el modelo de generación. | |
Returns: | |
str: Resumen final concatenado de todos los segmentos. | |
""" | |
# Validar parámetros de longitud | |
if batch_length < min_batch_length: | |
batch_length = min_batch_length | |
# Tokenizar texto completo en segmentos | |
encoded_input = tokenizer( | |
input_text, | |
return_tensors="pt", | |
max_length=batch_length, | |
truncation=True, | |
stride=stride, | |
return_overflowing_tokens=True, | |
add_special_tokens=True, | |
) | |
# Obtener IDs y máscaras de atención | |
input_ids = encoded_input["input_ids"] | |
attention_masks = encoded_input["attention_mask"] | |
# Progresión para múltiples segmentos | |
summaries = [] | |
pbar = tqdm(total=len(input_ids), desc="Procesando segmentos") | |
for ids, mask in zip(input_ids, attention_masks): | |
# Enviar al dispositivo correcto (CPU/GPU) | |
ids = ids.unsqueeze(0).to(model.device) | |
mask = mask.unsqueeze(0).to(model.device) | |
# Generar resumen para el segmento actual | |
outputs = model.generate( | |
input_ids=ids, | |
attention_mask=mask, | |
max_length=max_length, | |
no_repeat_ngram_size=3, | |
num_beams=4, | |
early_stopping=True, | |
**generate_kwargs, | |
) | |
# Decodificar resumen generado | |
summary = tokenizer.decode( | |
outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True | |
) | |
summaries.append(summary) | |
pbar.update() | |
pbar.close() | |
# Concatenar resúmenes y devolver el texto final | |
final_summary = " ".join(summaries) | |
return final_summary | |