jla25 commited on
Commit
5888de7
·
verified ·
1 Parent(s): 66a0eb1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -24
handler.py CHANGED
@@ -4,8 +4,8 @@ import json
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
7
- self.tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
8
- self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_418M")
9
  self.model.eval()
10
 
11
  def preprocess(self, data):
@@ -15,37 +15,23 @@ class EndpointHandler:
15
  # Prompt personalizado para guiar al modelo
16
  input_text = (
17
  f"""
18
- Genera un JSON válido con estas especificaciones:
19
- - Cada objeto tiene una clave 'id' y un valor 'value'.
20
- - Opciones para 'id': firstName, lastName, jobTitle, address, email, phone, notes, roleFunction.
21
- - Si 'id' es address, email o phone, debe incluir subclaves: MOBILE, WORK, PERSONAL, MAIN, OTHER.
22
- - 'roleFunction' debe ser una de estas: BUYER, SELLER, SUPPLIER, PARTNER, COLLABORATOR, PROVIDER, CUSTOMER.
23
- Ejemplo:
24
- Entrada: "Contacté a Juan Pérez, Gerente de Finanzas."
25
- Salida esperada:
26
- {{
27
- "values": [
28
- {{"id": "firstName", "value": "Juan"}},
29
- {{"id": "lastName", "value": "Pérez"}},
30
- {{"id": "jobTitle", "value": "Gerente de Finanzas"}}
31
- ]
32
- }}
33
- Procesa este texto: "{data['inputs']}"
34
  """)
35
  # Imprimir el texto generado para el prompt
36
  print(f"Prompt generado para el modelo: {input_text}")
37
- tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
38
  return tokens
39
 
40
  def inference(self, tokens):
41
  generate_kwargs = {
42
- "max_length": 1500,
43
- "num_beams": 7,
44
  "do_sample": False,
45
- "temperature": 0.1,
46
- "top_k": 10,
47
  "top_p": 0.7,
48
- "repetition_penalty": 2.8
49
  }
50
  with torch.no_grad():
51
  outputs = self.model.generate(**tokens, **generate_kwargs)
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
7
+ self.tokenizer = AutoTokenizer.from_pretrained("jla25/squareV3")
8
+ self.model = AutoModelForSeq2SeqLM.from_pretrained("jla25/squareV3")
9
  self.model.eval()
10
 
11
  def preprocess(self, data):
 
15
  # Prompt personalizado para guiar al modelo
16
  input_text = (
17
  f"""
18
+ ### Procesa el siguiente texto y genera un JSON válido:
19
+ "{data['inputs']}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  """)
21
  # Imprimir el texto generado para el prompt
22
  print(f"Prompt generado para el modelo: {input_text}")
23
+ tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
24
  return tokens
25
 
26
  def inference(self, tokens):
27
  generate_kwargs = {
28
+ "max_length": 1024,
29
+ "num_beams": 5,
30
  "do_sample": False,
31
+ "temperature": 0.3,
32
+ "top_k": 50,
33
  "top_p": 0.7,
34
+ "repetition_penalty": 2.5
35
  }
36
  with torch.no_grad():
37
  outputs = self.model.generate(**tokens, **generate_kwargs)