jla25 commited on
Commit
8708518
verified
1 Parent(s): c5a16f8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +18 -66
handler.py CHANGED
@@ -2,6 +2,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
  import json
4
  import jsonschema
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
@@ -29,76 +30,18 @@ class EndpointHandler:
29
  }
30
 
31
  def preprocess(self, data):
32
- # Validar la entrada
33
  if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
34
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
35
 
36
- # Construir el prompt con el formato especificado
37
  input_text = f"""
38
  Por favor, genera un JSON v谩lido basado en las siguientes especificaciones:
39
-
40
- Formato esperado:
41
- {{
42
- "values": [
43
- {{
44
- "id": "firstName",
45
- "value": "STRING"
46
- }},
47
- {{
48
- "id": "lastName",
49
- "value": "STRING"
50
- }},
51
- {{
52
- "id": "jobTitle",
53
- "value": "STRING"
54
- }},
55
- {{
56
- "id": "adress",
57
- "value": [
58
- {{
59
- "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
60
- "value": "STRING"
61
- }}
62
- ]
63
- }},
64
- {{
65
- "id": "email",
66
- "value": [
67
- {{
68
- "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
69
- "value": "STRING"
70
- }}
71
- ]
72
- }},
73
- {{
74
- "id": "phone",
75
- "value": [
76
- {{
77
- "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
78
- "value": "STRING (ONLY NUMBERS)"
79
- }}
80
- ]
81
- }},
82
- {{
83
- "id": "notes",
84
- "value": "STRING"
85
- }},
86
- {{
87
- "id": "roleFunction",
88
- "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
89
- }}
90
- ]
91
- }}
92
-
93
- Solo incluye los campos detectados en el texto de entrada.
94
  Procesa el siguiente texto: "{data['inputs']}"
95
  """
96
- # Tokenizar el texto de entrada
97
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
98
  return tokens
99
 
100
  def inference(self, tokens):
101
- # Par谩metros de generaci贸n
102
  generate_kwargs = {
103
  "max_length": 1000,
104
  "num_beams": 5,
@@ -108,29 +51,38 @@ class EndpointHandler:
108
  "top_p": 0.9,
109
  "repetition_penalty": 2.5
110
  }
111
- # Generar salida con el modelo
112
  with torch.no_grad():
113
  outputs = self.model.generate(**tokens, **generate_kwargs)
114
  return outputs
115
 
 
 
 
 
 
 
116
  def validate_json(self, decoded_output):
117
- # Validar el JSON generado con el esquema
118
  try:
119
- json_data = json.loads(decoded_output)
120
  jsonschema.validate(instance=json_data, schema=self.json_schema)
121
  return {"is_valid": True, "json_data": json_data}
122
  except json.JSONDecodeError as e:
123
- return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
124
  except jsonschema.exceptions.ValidationError as e:
125
- return {"is_valid": False, "error": f"Error validando JSON: {str(e)}"}
126
 
127
  def postprocess(self, outputs):
128
- # Decodificar la salida generada
129
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
130
 
131
- # Validar el JSON generado
132
  validation_result = self.validate_json(decoded_output)
 
 
 
 
133
  if not validation_result["is_valid"]:
 
 
134
  raise ValueError(f"JSON inv谩lido: {validation_result['error']}")
135
 
136
  return {"response": validation_result["json_data"]}
 
2
  import torch
3
  import json
4
  import jsonschema
5
+ import re
6
 
7
  class EndpointHandler:
8
  def __init__(self, model_dir):
 
30
  }
31
 
32
  def preprocess(self, data):
 
33
  if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
34
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
35
 
 
36
  input_text = f"""
37
  Por favor, genera un JSON v谩lido basado en las siguientes especificaciones:
38
+ ... (Especificaciones del formato JSON omitidas por brevedad)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  Procesa el siguiente texto: "{data['inputs']}"
40
  """
 
41
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
42
  return tokens
43
 
44
  def inference(self, tokens):
 
45
  generate_kwargs = {
46
  "max_length": 1000,
47
  "num_beams": 5,
 
51
  "top_p": 0.9,
52
  "repetition_penalty": 2.5
53
  }
 
54
  with torch.no_grad():
55
  outputs = self.model.generate(**tokens, **generate_kwargs)
56
  return outputs
57
 
58
+ def clean_output(self, output):
59
+ json_match = re.search(r"{.*}", output, re.DOTALL)
60
+ if json_match:
61
+ return json_match.group(0)
62
+ return output
63
+
64
  def validate_json(self, decoded_output):
65
+ cleaned_output = self.clean_output(decoded_output)
66
  try:
67
+ json_data = json.loads(cleaned_output)
68
  jsonschema.validate(instance=json_data, schema=self.json_schema)
69
  return {"is_valid": True, "json_data": json_data}
70
  except json.JSONDecodeError as e:
71
+ return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}", "raw_output": cleaned_output}
72
  except jsonschema.exceptions.ValidationError as e:
73
+ return {"is_valid": False, "error": f"Error validando JSON: {str(e)}", "raw_output": cleaned_output}
74
 
75
  def postprocess(self, outputs):
 
76
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
77
 
 
78
  validation_result = self.validate_json(decoded_output)
79
+
80
+ # Siempre imprimir la salida generada
81
+ print(f"Texto generado: {decoded_output}")
82
+
83
  if not validation_result["is_valid"]:
84
+ print(f"Error en la validaci贸n: {validation_result['error']}")
85
+ print(f"Salida sin procesar: {validation_result.get('raw_output', 'No disponible')}")
86
  raise ValueError(f"JSON inv谩lido: {validation_result['error']}")
87
 
88
  return {"response": validation_result["json_data"]}