from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Hugging Face Hub model id; also used below to instantiate the handler locally
model_name = "jla25/squareV3"
class EndpointHandler:
    # Custom handler for a Hugging Face Inference Endpoint: loads the seq2seq
    # model once and serves text -> JSON extraction requests.
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()
    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("Input must be a dictionary with an 'inputs' key and a valid value.")
        # Custom prompt to guide the model
        input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}"
        print(f"Prompt sent to the model: {input_text}")
        # Truncate long inputs; padding is unnecessary for a single sequence
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
        return tokens
    def inference(self, tokens):
        # Sampling parameters (temperature, top_k, top_p) are ignored when
        # do_sample=False, so only the beam-search settings are kept here.
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            "early_stopping": True,  # stop once enough beams have finished
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs
    def clean_output(self, output):
        # Keep only the span from the first '{' to the last '}', stripping any
        # stray text the model may emit around the JSON object.
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No braces found: return the raw text unchanged
            return output
    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)
        # Always print the generated text for debugging
        print(f"Raw model output: {decoded_output}")
        print(f"Cleaned JSON: {cleaned_output}")
        return {"response": cleaned_output}
    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
# Create a handler instance (Inference Endpoints pass the local model
# directory; here the Hub model id is used so the file also runs standalone)
handler = EndpointHandler(model_name)
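
# Minimal smoke test, a sketch only: the sample payload below is hypothetical
# and not part of the original handler. Running this file directly sends one
# request through the full preprocess -> generate -> postprocess pipeline.
if __name__ == "__main__":
    sample = {"inputs": "John Smith ordered 3 laptops for 2,400 dollars on May 5."}
    print(handler(sample))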