Overglitch commited on
Commit
007521f
·
verified ·
1 Parent(s): 3b5c554

Update modules/utils.py

Browse files
Files changed (1) hide show
  1. modules/utils.py +74 -14
modules/utils.py CHANGED
@@ -1,21 +1,81 @@
1
- def handle_long_text(text, model, tokenizer, max_length=2048, stride=128):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  encoded_input = tokenizer(
3
- text,
4
- max_length=max_length,
5
- stride=stride,
6
  truncation=True,
 
7
  return_overflowing_tokens=True,
8
- return_tensors="pt",
9
  )
 
 
 
 
 
 
10
  summaries = []
11
- for input_ids, attention_mask in zip(
12
- encoded_input.input_ids, encoded_input.attention_mask
13
- ):
14
- output = model.generate(
15
- input_ids.to(model.device),
16
- attention_mask=attention_mask.to(model.device),
17
- max_length=128,
 
 
 
 
 
 
18
  num_beams=4,
 
 
 
 
 
 
19
  )
20
- summaries.append(tokenizer.decode(output[0], skip_special_tokens=True))
21
- return " ".join(summaries)
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+ import torch
3
+ from tqdm.auto import tqdm
4
+
5
+
6
+ def handle_long_text(
7
+ input_text: str,
8
+ model: AutoModelForSeq2SeqLM,
9
+ tokenizer: AutoTokenizer,
10
+ max_length: int = 128,
11
+ stride: int = 128,
12
+ batch_length: int = 2048,
13
+ min_batch_length: int = 512,
14
+ **generate_kwargs,
15
+ ) -> str:
16
+ """
17
+ Maneja textos largos dividiéndolos en segmentos y generando resúmenes para cada uno.
18
+
19
+ Args:
20
+ input_text (str): Texto completo a resumir.
21
+ model: Modelo de resumen abstractivo.
22
+ tokenizer: Tokenizador asociado al modelo.
23
+ max_length (int): Longitud máxima del resumen generado por segmento.
24
+ stride (int): Cantidad de tokens que se superponen entre segmentos.
25
+ batch_length (int): Longitud máxima de tokens por segmento.
26
+ min_batch_length (int): Longitud mínima permitida por segmento.
27
+ generate_kwargs: Parámetros adicionales para el modelo de generación.
28
+
29
+ Returns:
30
+ str: Resumen final concatenado de todos los segmentos.
31
+ """
32
+ # Validar parámetros de longitud
33
+ if batch_length < min_batch_length:
34
+ batch_length = min_batch_length
35
+
36
+ # Tokenizar texto completo en segmentos
37
  encoded_input = tokenizer(
38
+ input_text,
39
+ return_tensors="pt",
40
+ max_length=batch_length,
41
  truncation=True,
42
+ stride=stride,
43
  return_overflowing_tokens=True,
44
+ add_special_tokens=True,
45
  )
46
+
47
+ # Obtener IDs y máscaras de atención
48
+ input_ids = encoded_input["input_ids"]
49
+ attention_masks = encoded_input["attention_mask"]
50
+
51
+ # Progresión para múltiples segmentos
52
  summaries = []
53
+ pbar = tqdm(total=len(input_ids), desc="Procesando segmentos")
54
+
55
+ for ids, mask in zip(input_ids, attention_masks):
56
+ # Enviar al dispositivo correcto (CPU/GPU)
57
+ ids = ids.unsqueeze(0).to(model.device)
58
+ mask = mask.unsqueeze(0).to(model.device)
59
+
60
+ # Generar resumen para el segmento actual
61
+ outputs = model.generate(
62
+ input_ids=ids,
63
+ attention_mask=mask,
64
+ max_length=max_length,
65
+ no_repeat_ngram_size=3,
66
  num_beams=4,
67
+ early_stopping=True,
68
+ **generate_kwargs,
69
+ )
70
+ # Decodificar resumen generado
71
+ summary = tokenizer.decode(
72
+ outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
73
  )
74
+ summaries.append(summary)
75
+ pbar.update()
76
+
77
+ pbar.close()
78
+
79
+ # Concatenar resúmenes y devolver el texto final
80
+ final_summary = " ".join(summaries)
81
+ return final_summary