Update handler.py
handler.py  CHANGED  +10 -24
@@ -4,8 +4,8 @@ import json
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        self.tokenizer = AutoTokenizer.from_pretrained("
-        self.model = AutoModelForSeq2SeqLM.from_pretrained("
+        self.tokenizer = AutoTokenizer.from_pretrained("jla25/squareV3")
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("jla25/squareV3")
         self.model.eval()
 
     def preprocess(self, data):
@@ -15,37 +15,23 @@ class EndpointHandler:
         # Custom prompt to guide the model
         input_text = (
             f"""
-
-
-            - Opciones para 'id': firstName, lastName, jobTitle, address, email, phone, notes, roleFunction.
-            - Si 'id' es address, email o phone, debe incluir subclaves: MOBILE, WORK, PERSONAL, MAIN, OTHER.
-            - 'roleFunction' debe ser una de estas: BUYER, SELLER, SUPPLIER, PARTNER, COLLABORATOR, PROVIDER, CUSTOMER.
-            Ejemplo:
-            Entrada: "Contacté a Juan Pérez, Gerente de Finanzas."
-            Salida esperada:
-            {{
-              "values": [
-                {{"id": "firstName", "value": "Juan"}},
-                {{"id": "lastName", "value": "Pérez"}},
-                {{"id": "jobTitle", "value": "Gerente de Finanzas"}}
-              ]
-            }}
-            Procesa este texto: "{data['inputs']}"
+            ### Procesa el siguiente texto y genera un JSON válido:
+            "{data['inputs']}"
             """)
         # Print the generated prompt text
         print(f"Prompt generado para el modelo: {input_text}")
-        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=
+        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
         return tokens
 
     def inference(self, tokens):
         generate_kwargs = {
-            "max_length":
-            "num_beams":
+            "max_length": 1024,
+            "num_beams": 5,
             "do_sample": False,
-            "temperature": 0.
-            "top_k":
+            "temperature": 0.3,
+            "top_k": 50,
             "top_p": 0.7,
-            "repetition_penalty": 2.
+            "repetition_penalty": 2.5
         }
         with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
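For quick local verification of the updated handler, a driver along the lines below can be used. This is a minimal sketch, not part of the commit: it assumes handler.py is importable as shown, that the file's existing imports (AutoTokenizer, AutoModelForSeq2SeqLM, torch) are intact, and that inference() returns the tensor produced by model.generate(); the return statement falls outside the visible hunks.

# Minimal local smoke test (illustrative only; not part of this commit).
# Assumes handler.py is on the import path and that inference() returns
# the output of model.generate(), which is not visible in these hunks.
from handler import EndpointHandler

handler = EndpointHandler(model_dir=".")  # model_dir is unused by __init__ after this change
tokens = handler.preprocess({"inputs": "Contacté a Juan Pérez, Gerente de Finanzas."})
output_ids = handler.inference(tokens)

# Decode the best beam; skip_special_tokens strips pad/eos markers.
print(handler.tokenizer.decode(output_ids[0], skip_special_tokens=True))

One side note on the new generate_kwargs: with do_sample=False the call runs deterministic beam search (num_beams=5), so temperature, top_k, and top_p have no effect, and recent transformers releases warn about such unused sampling flags; repetition_penalty and max_length still apply.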