jla25 commited on
Commit
d584a44
verified
1 Parent(s): 8708518

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +27 -41
handler.py CHANGED
@@ -1,8 +1,6 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
  import json
4
- import jsonschema
5
- import re
6
 
7
  class EndpointHandler:
8
  def __init__(self, model_dir):
@@ -10,32 +8,14 @@ class EndpointHandler:
10
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
11
  self.model.eval()
12
 
13
- # Esquema de validaci贸n del JSON
14
- self.json_schema = {
15
- "type": "object",
16
- "properties": {
17
- "values": {
18
- "type": "array",
19
- "items": {
20
- "type": "object",
21
- "properties": {
22
- "id": {"type": "string"},
23
- "value": {"type": ["string", "array"]}
24
- },
25
- "required": ["id", "value"]
26
- },
27
- },
28
- },
29
- "required": ["values"],
30
- }
31
-
32
  def preprocess(self, data):
33
  if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
34
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
35
 
 
36
  input_text = f"""
37
- Por favor, genera un JSON v谩lido basado en las siguientes especificaciones:
38
- ... (Especificaciones del formato JSON omitidas por brevedad)
39
  Procesa el siguiente texto: "{data['inputs']}"
40
  """
41
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
@@ -56,35 +36,41 @@ class EndpointHandler:
56
  return outputs
57
 
58
  def clean_output(self, output):
59
- json_match = re.search(r"{.*}", output, re.DOTALL)
60
- if json_match:
61
- return json_match.group(0)
62
- return output
 
 
 
 
63
 
64
- def validate_json(self, decoded_output):
65
- cleaned_output = self.clean_output(decoded_output)
66
  try:
67
- json_data = json.loads(cleaned_output)
68
- jsonschema.validate(instance=json_data, schema=self.json_schema)
69
- return {"is_valid": True, "json_data": json_data}
 
 
70
  except json.JSONDecodeError as e:
71
- return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}", "raw_output": cleaned_output}
72
- except jsonschema.exceptions.ValidationError as e:
73
- return {"is_valid": False, "error": f"Error validando JSON: {str(e)}", "raw_output": cleaned_output}
74
 
75
  def postprocess(self, outputs):
76
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
77
-
78
- validation_result = self.validate_json(decoded_output)
79
-
80
- # Siempre imprimir la salida generada
81
  print(f"Texto generado: {decoded_output}")
 
 
 
 
82
 
83
  if not validation_result["is_valid"]:
84
  print(f"Error en la validaci贸n: {validation_result['error']}")
85
- print(f"Salida sin procesar: {validation_result.get('raw_output', 'No disponible')}")
86
  raise ValueError(f"JSON inv谩lido: {validation_result['error']}")
87
-
88
  return {"response": validation_result["json_data"]}
89
 
90
  def __call__(self, data):
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
  import json
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
 
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
9
  self.model.eval()
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def preprocess(self, data):
12
  if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
13
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
14
 
15
+ # Prompt personalizado para guiar al modelo
16
  input_text = f"""
17
+ Genera un JSON v谩lido en el siguiente formato preentrenado:
18
+ {{\"values\": [{\"id\": \"firstName\", \"value\": \"STRING\"},{\"id\": \"lastName\", \"value\": \"STRING\"},{\"id\": \"jobTitle\", \"value\": \"STRING\"},{\"id\": \"adress\", \"value\": [{\"id\": \"[MOBILE-WORK-PERSONAL-MAIN-OTHER]\", \"value\": \"STRING\"}]},{\"id\": \"email\", \"value\": [{\"id\": \"[MOBILE-WORK-PERSONAL-MAIN-OTHER]\", \"value\": \"STRING\"}]},{\"id\": \"phone\", \"value\": [{\"id\": \"[MOBILE-WORK-PERSONAL-MAIN-OTHER]\", \"value\": \"STRING (ONLY NUMBERS)\"}]},{\"id\": \"notes\", \"value\": \"STRING\"},{\"id\": \"roleFunction\", \"value\": \"[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]\"}]}}
19
  Procesa el siguiente texto: "{data['inputs']}"
20
  """
21
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
 
36
  return outputs
37
 
38
  def clean_output(self, output):
39
+ # Extraer el JSON dentro del texto generado
40
+ try:
41
+ start_index = output.index("{")
42
+ end_index = output.rindex("}") + 1
43
+ return output[start_index:end_index]
44
+ except ValueError:
45
+ # Si no hay un JSON v谩lido en el texto
46
+ return output
47
 
48
+ def validate_json(self, json_text):
49
+ # Validar el JSON generado
50
  try:
51
+ json_data = json.loads(json_text)
52
+ if "values" in json_data and isinstance(json_data["values"], list):
53
+ return {"is_valid": True, "json_data": json_data}
54
+ else:
55
+ return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}
56
  except json.JSONDecodeError as e:
57
+ return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
 
 
58
 
59
  def postprocess(self, outputs):
60
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
61
+ cleaned_output = self.clean_output(decoded_output)
62
+
63
+ # Imprimir siempre el texto generado para depuraci贸n
 
64
  print(f"Texto generado: {decoded_output}")
65
+ print(f"JSON limpiado: {cleaned_output}")
66
+
67
+ # Validar el JSON generado
68
+ validation_result = self.validate_json(cleaned_output)
69
 
70
  if not validation_result["is_valid"]:
71
  print(f"Error en la validaci贸n: {validation_result['error']}")
 
72
  raise ValueError(f"JSON inv谩lido: {validation_result['error']}")
73
+
74
  return {"response": validation_result["json_data"]}
75
 
76
  def __call__(self, data):