jla25 committed on
Commit
05da434
verified
1 Parent(s): a87db35

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -10
handler.py CHANGED
@@ -1,6 +1,5 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
- import json
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
@@ -9,21 +8,29 @@ class EndpointHandler:
9
  self.model.eval()
10
 
11
  def preprocess(self, data):
12
- # Validar entrada
13
- if not data or not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
14
- raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido")
15
-
16
  input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
17
- if not input_text.strip():
18
- raise ValueError("El texto de entrada no puede estar vacío")
19
-
20
- # Tokenización de la entrada
21
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
 
 
22
  return tokens
23
 
24
  def inference(self, tokens):
 
 
 
 
 
 
 
 
 
25
  with torch.no_grad():
26
- outputs = self.model.generate(**tokens)
 
 
27
  return outputs
28
 
29
  def postprocess(self, outputs):
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, model_dir):
 
8
  self.model.eval()
9
 
10
  def preprocess(self, data):
11
+ if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
12
+ raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
13
+
 
14
  input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
 
 
 
 
15
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
16
+ if not tokens or not tokens["input_ids"]:
17
+ raise ValueError("Error al tokenizar el texto de entrada. Verifica el texto.")
18
  return tokens
19
 
20
  def inference(self, tokens):
21
+ generate_kwargs = {
22
+ "max_length": 1000,
23
+ "num_beams": 5,
24
+ "do_sample": True,
25
+ "temperature": 0.3,
26
+ "top_k": 50,
27
+ "top_p": 0.9,
28
+ "repetition_penalty": 2.5
29
+ }
30
  with torch.no_grad():
31
+ outputs = self.model.generate(**tokens, **generate_kwargs)
32
+ if outputs is None or len(outputs) == 0:
33
+ raise ValueError("El modelo no generó ninguna salida.")
34
  return outputs
35
 
36
  def postprocess(self, outputs):