jla25 commited on
Commit
a87db35
verified
1 Parent(s): 0dd4228

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -11
handler.py CHANGED
@@ -4,36 +4,34 @@ import json
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
7
- # Cargar el modelo y el tokenizador desde el directorio del modelo
8
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
9
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
10
- self.model.eval() # Configurar el modelo en modo de evaluaci贸n
11
 
12
  def preprocess(self, data):
13
- # Preprocesamiento de la entrada
14
- if isinstance(data, dict) and "inputs" in data:
15
- input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
16
- else:
17
- raise ValueError("Esperando un diccionario con la clave 'inputs'")
 
 
18
 
19
  # Tokenizaci贸n de la entrada
20
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
21
  return tokens
22
 
23
  def inference(self, tokens):
24
- # Realizar la inferencia
25
  with torch.no_grad():
26
  outputs = self.model.generate(**tokens)
27
  return outputs
28
 
29
  def postprocess(self, outputs):
30
- # Decodificar la salida del modelo
31
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
32
  return {"generated_text": decoded_output}
33
 
34
  def __call__(self, data):
35
- # Llamada principal del handler para procesamiento completo
36
  tokens = self.preprocess(data)
37
  outputs = self.inference(tokens)
38
  result = self.postprocess(outputs)
39
- return result
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
 
7
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
9
+ self.model.eval()
10
 
11
  def preprocess(self, data):
12
+ # Validar entrada
13
+ if not data or not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
14
+ raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido")
15
+
16
+ input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
17
+ if not input_text.strip():
18
+ raise ValueError("El texto de entrada no puede estar vac铆o")
19
 
20
  # Tokenizaci贸n de la entrada
21
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
22
  return tokens
23
 
24
  def inference(self, tokens):
 
25
  with torch.no_grad():
26
  outputs = self.model.generate(**tokens)
27
  return outputs
28
 
29
  def postprocess(self, outputs):
 
30
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
31
  return {"generated_text": decoded_output}
32
 
33
  def __call__(self, data):
 
34
  tokens = self.preprocess(data)
35
  outputs = self.inference(tokens)
36
  result = self.postprocess(outputs)
37
+ return result