jla25 commited on
Commit
11fa52a
verified
1 Parent(s): 1bfbfc2

basic test

Browse files
Files changed (1) hide show
  1. handler.py +4 -36
handler.py CHANGED
@@ -1,39 +1,7 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- import torch
3
- import json
4
-
5
- class EndpointHandler: # Aseg煤rate de que el nombre sea EndpointHandler
6
  def __init__(self, model_dir):
7
- # Cargar el modelo y el tokenizador desde el directorio del modelo
8
- self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
9
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
10
- self.model.eval() # Configurar el modelo en modo de evaluaci贸n
11
-
12
- def preprocess(self, data):
13
- # Preprocesamiento de la entrada
14
- if isinstance(data, dict) and "input_text" in data:
15
- input_text = data["input_text"]
16
- else:
17
- raise ValueError("Esperando un diccionario con la clave 'input_text'")
18
-
19
- # Tokenizaci贸n de la entrada
20
- tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
21
- return tokens
22
-
23
- def inference(self, tokens):
24
- # Realizar la inferencia
25
- with torch.no_grad():
26
- outputs = self.model.generate(**tokens)
27
- return outputs
28
-
29
- def postprocess(self, outputs):
30
- # Decodificar la salida del modelo
31
- decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
32
- return {"generated_text": decoded_output}
33
 
34
  def __call__(self, data):
35
- # Llamada principal del handler para procesamiento completo
36
- tokens = self.preprocess(data)
37
- outputs = self.inference(tokens)
38
- result = self.postprocess(outputs)
39
- return result
 
1
+ class EndpointHandler:
 
 
 
 
2
  def __init__(self, model_dir):
3
+ print("Inicializando el modelo")
4
+ self.model_dir = model_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def __call__(self, data):
7
+ return {"message": "Este es un mensaje de prueba"}