jla25 commited on
Commit
0d16e19
verified
1 Parent(s): 7504697

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -0
handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
+ import json
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, model_dir):
7
+ # Cargar el modelo y el tokenizador desde el directorio del modelo
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
9
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
10
+ self.model.eval() # Configurar el modelo en modo de evaluaci贸n
11
+
12
+ def preprocess(self, data):
13
+ # Preprocesamiento de la entrada
14
+ if isinstance(data, dict) and "inputs" in data:
15
+ input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
16
+ else:
17
+ raise ValueError("Esperando un diccionario con la clave 'inputs'")
18
+
19
+ # Tokenizaci贸n de la entrada
20
+ tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
21
+ return tokens
22
+
23
+ def inference(self, tokens):
24
+ # Realizar la inferencia
25
+ with torch.no_grad():
26
+ outputs = self.model.generate(**tokens)
27
+ return outputs
28
+
29
+ def postprocess(self, outputs):
30
+ # Decodificar la salida del modelo
31
+ decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
32
+ return {"generated_text": decoded_output}
33
+
34
+ def __call__(self, data):
35
+ # Llamada principal del handler para procesamiento completo
36
+ tokens = self.preprocess(data)
37
+ outputs = self.inference(tokens)
38
+ result = self.postprocess(outputs)
39
+ return result