jla25 commited on
Commit
258e3f1
verified
1 Parent(s): b9ba426

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -0
handler.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
+
4
+ class ModelHandler:
5
+ def __init__(self, model_dir):
6
+ # Cargar el modelo y el tokenizador desde el directorio del modelo
7
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
8
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
9
+ self.model.eval() # Configurar el modelo en modo de evaluaci贸n
10
+
11
+ def preprocess(self, data):
12
+ # Preprocesamiento de la entrada
13
+ if isinstance(data, dict) and "input_text" in data:
14
+ input_text = data["input_text"]
15
+ else:
16
+ raise ValueError("Esperando un diccionario con la clave 'inputs'")
17
+
18
+ # Tokenizaci贸n de la entrada
19
+ tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
20
+ return tokens
21
+
22
+ def inference(self, tokens):
23
+ # Realizar la inferencia
24
+ with torch.no_grad():
25
+ outputs = self.model(**tokens)
26
+ # Obtener las predicciones y aplicar softmax para probabilidades
27
+ probabilities = torch.softmax(outputs.logits, dim=-1)
28
+ return probabilities
29
+
30
+ def postprocess(self, probabilities):
31
+ # Postprocesamiento para devolver la salida en formato JSON
32
+ predictions = torch.argmax(probabilities, dim=-1)
33
+ return {"predictions": predictions.tolist(), "probabilities": probabilities.tolist()}
34
+
35
+ def __call__(self, data):
36
+ # Llamada principal del handler para procesamiento completo
37
+ tokens = self.preprocess(data)
38
+ probabilities = self.inference(tokens)
39
+ result = self.postprocess(probabilities)
40
+ return result