Update handler.py
Browse files- handler.py +12 -28
handler.py
CHANGED
@@ -5,14 +5,14 @@ import json
|
|
5 |
|
6 |
model_name = "jla25/squareV3"
|
7 |
|
8 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
10 |
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, model_dir):
|
14 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
15 |
-
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
16 |
self.model.eval()
|
17 |
|
18 |
def preprocess(self, data):
|
@@ -20,9 +20,9 @@ class EndpointHandler:
|
|
20 |
raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
|
21 |
|
22 |
# Prompt personalizado para guiar al modelo
|
23 |
-
input_text =
|
24 |
-
# Imprimir el texto generado para el prompt
|
25 |
print(f"Prompt generado para el modelo: {input_text}")
|
|
|
26 |
tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
|
27 |
return tokens
|
28 |
|
@@ -41,45 +41,29 @@ class EndpointHandler:
|
|
41 |
return outputs
|
42 |
|
43 |
def clean_output(self, output):
|
44 |
-
# Extraer el JSON dentro del texto generado
|
45 |
try:
|
46 |
start_index = output.index("{")
|
47 |
end_index = output.rindex("}") + 1
|
48 |
return output[start_index:end_index]
|
49 |
except ValueError:
|
50 |
-
# Si no hay un JSON v谩lido en el texto
|
51 |
return output
|
52 |
|
53 |
-
def validate_json(self, json_text):
|
54 |
-
# Validar el JSON generado
|
55 |
-
try:
|
56 |
-
json_data = json.loads(json_text)
|
57 |
-
if "values" in json_data and isinstance(json_data["values"], list):
|
58 |
-
return {"is_valid": True, "json_data": json_data}
|
59 |
-
else:
|
60 |
-
return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}
|
61 |
-
except json.JSONDecodeError as e:
|
62 |
-
return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
|
63 |
-
|
64 |
def postprocess(self, outputs):
|
65 |
decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
66 |
cleaned_output = self.clean_output(decoded_output)
|
67 |
|
68 |
# Imprimir siempre el texto generado para depuraci贸n
|
69 |
-
print(f"Texto generado: {decoded_output}")
|
70 |
print(f"JSON limpiado: {cleaned_output}")
|
71 |
|
72 |
-
|
73 |
-
validation_result = self.validate_json(cleaned_output)
|
74 |
-
|
75 |
-
if not validation_result["is_valid"]:
|
76 |
-
print(f"Error en la validaci贸n: {validation_result['error']}")
|
77 |
-
raise ValueError(f"JSON inv谩lido: {validation_result['error']}")
|
78 |
-
|
79 |
-
return {"response": validation_result["json_data"]}
|
80 |
|
81 |
def __call__(self, data):
|
82 |
tokens = self.preprocess(data)
|
83 |
outputs = self.inference(tokens)
|
84 |
result = self.postprocess(outputs)
|
85 |
return result
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# Hub repository id of the fine-tuned seq2seq model served by this handler.
model_name = "jla25/squareV3"

# Module-level (import-time) load of tokenizer and model from the Hub.
# NOTE(review): these globals look redundant with the per-instance load in
# EndpointHandler.__init__ — confirm whether any caller actually uses them.
# NOTE(review): `use_auth_token` is deprecated in recent `transformers`
# releases in favour of `token` — confirm the pinned library version before
# migrating.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
|
10 |
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, model_dir):
    """Load the tokenizer and seq2seq model from ``model_dir`` and switch
    the model to inference mode.

    NOTE(review): `use_auth_token` is deprecated in recent `transformers`
    releases in favour of `token` — confirm the pinned library version.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_auth_token=True)
    self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, use_auth_token=True)
    # Disable dropout / training-only behavior for deterministic inference.
    self.model.eval()
|
17 |
|
18 |
def preprocess(self, data):
|
|
|
20 |
raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
|
21 |
|
22 |
# Prompt personalizado para guiar al modelo
|
23 |
+
input_text = f"{data['inputs']}"
|
|
|
24 |
print(f"Prompt generado para el modelo: {input_text}")
|
25 |
+
|
26 |
tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
|
27 |
return tokens
|
28 |
|
|
|
41 |
return outputs
|
42 |
|
43 |
def clean_output(self, output):
    """Best-effort JSON extraction from generated text.

    Returns the substring spanning the first ``{`` through the last ``}``
    of ``output``; if either brace is missing, the text is returned
    unchanged.
    """
    try:
        first_brace = output.index("{")
        last_brace = output.rindex("}")
    except ValueError:
        # No opening or no closing brace — nothing to extract.
        return output
    return output[first_brace:last_brace + 1]
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def postprocess(self, outputs):
    """Decode the generated token ids, extract the JSON-looking span via
    ``clean_output``, and wrap it in the response payload."""
    raw_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    json_text = self.clean_output(raw_text)

    # Always log the generated text, for debugging.
    print(f"Texto generado por el modelo: {raw_text}")
    print(f"JSON limpiado: {json_text}")

    return {"response": json_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def __call__(self, data):
    """Run the full request pipeline: preprocess -> inference -> postprocess.

    ``data`` is the raw request payload; the returned value is whatever
    ``postprocess`` produces (a dict with a "response" key).
    """
    tokens = self.preprocess(data)
    outputs = self.inference(tokens)
    result = self.postprocess(outputs)
    return result
|
66 |
+
|
67 |
+
|
68 |
+
# Create a module-level handler instance.
# NOTE(review): this passes the hub repo id (`model_name`) where __init__
# names the parameter `model_dir`; serving stacks often instantiate the
# handler themselves with a local path — confirm this import-time instance
# (and the download it triggers) is actually required.
handler = EndpointHandler(model_name)
|