jla25 commited on
Commit
c1975d1
verified
1 Parent(s): e29f84e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -28
handler.py CHANGED
@@ -5,14 +5,14 @@ import json
5
 
6
  model_name = "jla25/squareV3"
7
 
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
 
12
  class EndpointHandler:
13
  def __init__(self, model_dir):
14
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
16
  self.model.eval()
17
 
18
  def preprocess(self, data):
@@ -20,9 +20,9 @@ class EndpointHandler:
20
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
21
 
22
  # Prompt personalizado para guiar al modelo
23
- input_text = ({data['inputs']})
24
- # Imprimir el texto generado para el prompt
25
  print(f"Prompt generado para el modelo: {input_text}")
 
26
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
27
  return tokens
28
 
@@ -41,45 +41,29 @@ class EndpointHandler:
41
  return outputs
42
 
43
  def clean_output(self, output):
44
- # Extraer el JSON dentro del texto generado
45
  try:
46
  start_index = output.index("{")
47
  end_index = output.rindex("}") + 1
48
  return output[start_index:end_index]
49
  except ValueError:
50
- # Si no hay un JSON v谩lido en el texto
51
  return output
52
 
53
- def validate_json(self, json_text):
54
- # Validar el JSON generado
55
- try:
56
- json_data = json.loads(json_text)
57
- if "values" in json_data and isinstance(json_data["values"], list):
58
- return {"is_valid": True, "json_data": json_data}
59
- else:
60
- return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}
61
- except json.JSONDecodeError as e:
62
- return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
63
-
64
  def postprocess(self, outputs):
65
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
66
  cleaned_output = self.clean_output(decoded_output)
67
 
68
  # Imprimir siempre el texto generado para depuraci贸n
69
- print(f"Texto generado: {decoded_output}")
70
  print(f"JSON limpiado: {cleaned_output}")
71
 
72
- # Validar el JSON generado
73
- validation_result = self.validate_json(cleaned_output)
74
-
75
- if not validation_result["is_valid"]:
76
- print(f"Error en la validaci贸n: {validation_result['error']}")
77
- raise ValueError(f"JSON inv谩lido: {validation_result['error']}")
78
-
79
- return {"response": validation_result["json_data"]}
80
 
81
  def __call__(self, data):
82
  tokens = self.preprocess(data)
83
  outputs = self.inference(tokens)
84
  result = self.postprocess(outputs)
85
  return result
 
 
 
 
 
5
 
6
  model_name = "jla25/squareV3"
7
 
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
10
 
11
 
12
  class EndpointHandler:
13
  def __init__(self, model_dir):
14
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_auth_token=True)
15
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, use_auth_token=True)
16
  self.model.eval()
17
 
18
  def preprocess(self, data):
 
20
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
21
 
22
  # Prompt personalizado para guiar al modelo
23
+ input_text = f"{data['inputs']}"
 
24
  print(f"Prompt generado para el modelo: {input_text}")
25
+
26
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
27
  return tokens
28
 
 
41
  return outputs
42
 
43
  def clean_output(self, output):
 
44
  try:
45
  start_index = output.index("{")
46
  end_index = output.rindex("}") + 1
47
  return output[start_index:end_index]
48
  except ValueError:
 
49
  return output
50
 
 
 
 
 
 
 
 
 
 
 
 
51
  def postprocess(self, outputs):
52
  decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
53
  cleaned_output = self.clean_output(decoded_output)
54
 
55
  # Imprimir siempre el texto generado para depuraci贸n
56
+ print(f"Texto generado por el modelo: {decoded_output}")
57
  print(f"JSON limpiado: {cleaned_output}")
58
 
59
+ return {"response": cleaned_output}
 
 
 
 
 
 
 
60
 
61
  def __call__(self, data):
62
  tokens = self.preprocess(data)
63
  outputs = self.inference(tokens)
64
  result = self.postprocess(outputs)
65
  return result
66
+
67
+
68
+ # Crear una instancia del handler
69
+ handler = EndpointHandler(model_name)