jla25 committed on
Commit
c5a16f8
verified
1 Parent(s): 333ea63

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +100 -3
handler.py CHANGED
@@ -1,5 +1,7 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
 
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, model_dir):
@@ -7,13 +9,91 @@ class EndpointHandler:
7
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
8
  self.model.eval()
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def preprocess(self, data):
11
  # Validar la entrada
12
  if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
13
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
14
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Tokenizar el texto de entrada
16
- input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
17
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
18
  return tokens
19
 
@@ -33,10 +113,27 @@ class EndpointHandler:
33
  outputs = self.model.generate(**tokens, **generate_kwargs)
34
  return outputs
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  def postprocess(self, outputs):
37
  # Decodificar la salida generada
38
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
39
- return {"response": decoded_output}
 
 
 
 
 
 
40
 
41
  def __call__(self, data):
42
  tokens = self.preprocess(data)
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
+ import json
4
+ import jsonschema
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
 
9
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
10
  self.model.eval()
11
 
12
+ # Esquema de validación del JSON
13
+ self.json_schema = {
14
+ "type": "object",
15
+ "properties": {
16
+ "values": {
17
+ "type": "array",
18
+ "items": {
19
+ "type": "object",
20
+ "properties": {
21
+ "id": {"type": "string"},
22
+ "value": {"type": ["string", "array"]}
23
+ },
24
+ "required": ["id", "value"]
25
+ },
26
+ },
27
+ },
28
+ "required": ["values"],
29
+ }
30
+
31
  def preprocess(self, data):
32
  # Validar la entrada
33
  if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
34
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
35
+
36
+ # Construir el prompt con el formato especificado
37
+ input_text = f"""
38
+ Por favor, genera un JSON válido basado en las siguientes especificaciones:
39
+
40
+ Formato esperado:
41
+ {{
42
+ "values": [
43
+ {{
44
+ "id": "firstName",
45
+ "value": "STRING"
46
+ }},
47
+ {{
48
+ "id": "lastName",
49
+ "value": "STRING"
50
+ }},
51
+ {{
52
+ "id": "jobTitle",
53
+ "value": "STRING"
54
+ }},
55
+ {{
56
+ "id": "adress",
57
+ "value": [
58
+ {{
59
+ "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
60
+ "value": "STRING"
61
+ }}
62
+ ]
63
+ }},
64
+ {{
65
+ "id": "email",
66
+ "value": [
67
+ {{
68
+ "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
69
+ "value": "STRING"
70
+ }}
71
+ ]
72
+ }},
73
+ {{
74
+ "id": "phone",
75
+ "value": [
76
+ {{
77
+ "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
78
+ "value": "STRING (ONLY NUMBERS)"
79
+ }}
80
+ ]
81
+ }},
82
+ {{
83
+ "id": "notes",
84
+ "value": "STRING"
85
+ }},
86
+ {{
87
+ "id": "roleFunction",
88
+ "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
89
+ }}
90
+ ]
91
+ }}
92
+
93
+ Solo incluye los campos detectados en el texto de entrada.
94
+ Procesa el siguiente texto: "{data['inputs']}"
95
+ """
96
  # Tokenizar el texto de entrada
 
97
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
98
  return tokens
99
 
 
113
  outputs = self.model.generate(**tokens, **generate_kwargs)
114
  return outputs
115
 
116
+ def validate_json(self, decoded_output):
117
+ # Validar el JSON generado con el esquema
118
+ try:
119
+ json_data = json.loads(decoded_output)
120
+ jsonschema.validate(instance=json_data, schema=self.json_schema)
121
+ return {"is_valid": True, "json_data": json_data}
122
+ except json.JSONDecodeError as e:
123
+ return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
124
+ except jsonschema.exceptions.ValidationError as e:
125
+ return {"is_valid": False, "error": f"Error validando JSON: {str(e)}"}
126
+
127
  def postprocess(self, outputs):
128
  # Decodificar la salida generada
129
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
130
+
131
+ # Validar el JSON generado
132
+ validation_result = self.validate_json(decoded_output)
133
+ if not validation_result["is_valid"]:
134
+ raise ValueError(f"JSON inválido: {validation_result['error']}")
135
+
136
+ return {"response": validation_result["json_data"]}
137
 
138
  def __call__(self, data):
139
  tokens = self.preprocess(data)