jla25 commited on
Commit
032fa14
verified
1 Parent(s): 11fa52a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -4
handler.py CHANGED
@@ -1,7 +1,40 @@
1
- class EndpointHandler:
 
 
 
 
2
  def __init__(self, model_dir):
3
- print("Inicializando el modelo")
4
- self.model_dir = model_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def __call__(self, data):
7
- return {"message": "Este es un mensaje de prueba"}
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
+ import json
4
+
5
+ class EndpointHandler: # Aseg煤rate de que el nombre sea EndpointHandler
6
  def __init__(self, model_dir):
7
+ # Cargar el modelo y el tokenizador desde el directorio del modelo
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
9
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
10
+ self.model.eval() # Configurar el modelo en modo de evaluaci贸n
11
+
12
+ def preprocess(self, data):
13
+ # Preprocesamiento de la entrada
14
+ if isinstance(data, dict) and "input_text" in data:
15
+ input_text = data["input_text"]
16
+ else:
17
+ raise ValueError("Esperando un diccionario con la clave 'input_text'")
18
+
19
+ # Tokenizaci贸n de la entrada
20
+ tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
21
+ return tokens
22
+
23
+ def inference(self, tokens):
24
+ # Realizar la inferencia
25
+ with torch.no_grad():
26
+ outputs = self.model.generate(**tokens)
27
+ return outputs
28
+
29
+ def postprocess(self, outputs):
30
+ # Decodificar la salida del modelo
31
+ decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
32
+ return {"generated_text": decoded_output}
33
 
34
  def __call__(self, data):
35
+ # Llamada principal del handler para procesamiento completo
36
+ tokens = self.preprocess(data)
37
+ outputs = self.inference(tokens)
38
+ result = self.postprocess(outputs)
39
+ return result
40
+