jla25 committed
Commit 333ea63 (verified)
1 Parent(s): 4722f73

Update handler.py

Files changed (1): handler.py (+6 −5)
handler.py CHANGED
@@ -8,16 +8,17 @@ class EndpointHandler:
         self.model.eval()
 
     def preprocess(self, data):
+        # Validate the input
         if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
             raise ValueError("The input must be a dictionary with the key 'inputs' and a valid value.")
 
+        # Tokenize the input text
         input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
         tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
-        if not tokens or not tokens["input_ids"]:
-            raise ValueError("Error tokenizing the input text. Check the text.")
         return tokens
 
     def inference(self, tokens):
+        # Generation parameters
         generate_kwargs = {
             "max_length": 1000,
             "num_beams": 5,
@@ -27,15 +28,15 @@ class EndpointHandler:
             "top_p": 0.9,
             "repetition_penalty": 2.5
         }
+        # Generate output with the model
         with torch.no_grad():
             outputs = self.model.generate(**tokens, **generate_kwargs)
-        if outputs is None or len(outputs) == 0:
-            raise ValueError("The model did not generate any output.")
         return outputs
 
     def postprocess(self, outputs):
+        # Decode the generated output
         decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return {"generated_text": decoded_output}
+        return {"response": decoded_output}
 
     def __call__(self, data):
         tokens = self.preprocess(data)
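
For orientation, a minimal sketch of how the handler would be exercised after this commit, following the usual Hugging Face Inference Endpoints convention of a callable handler class. The constructor argument and the sample payload are assumptions, since __init__ is not part of this diff.

# Minimal usage sketch (assumption: EndpointHandler.__init__ accepts a model
# path, as is conventional for custom Inference Endpoints handlers).
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # hypothetical constructor argument

# preprocess() requires a dict with a non-None "inputs" key,
# otherwise it raises ValueError.
payload = {"inputs": "Invoice 1042, issued 2024-01-15, total 99.50 EUR"}

result = handler(payload)
# As of this commit the output key is "response" (previously "generated_text").
print(result["response"])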