|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
import json |
|
|
|
|
|
# Hub id of the fine-tuned seq2seq model served by this endpoint.
model_name = "jla25/squareV3"



# NOTE(review): these module-level objects duplicate the load performed in
# EndpointHandler.__init__ (the `handler` created at the bottom of this file
# loads the same model a second time), doubling startup time and memory.
# They look unused within this file — confirm no external importer relies on
# `tokenizer`/`model` before removing them.
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
|
|
|
class EndpointHandler:
    """Inference handler that turns free-form text into a JSON-like string.

    Implements the callable-handler protocol: ``__call__`` receives a payload
    of the form ``{"inputs": <text>}`` and returns ``{"response": <string>}``,
    where the string is the model output trimmed to its outermost ``{...}``
    span when one is present.
    """

    def __init__(self, model_dir):
        """Load tokenizer and seq2seq model from *model_dir* (hub id or path)."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        # Inference only: disable dropout / training-mode layers.
        self.model.eval()

    def preprocess(self, data):
        """Validate the payload and tokenize the generation prompt.

        Args:
            data: expected to be a dict with a non-None "inputs" value.

        Returns:
            Tokenizer output (tensors) ready for ``model.generate``.

        Raises:
            ValueError: if *data* is not a dict or lacks a usable "inputs".
        """
        # dict.get covers both "key missing" and "value is None" in one test.
        if not isinstance(data, dict) or data.get("inputs") is None:
            # Fixed mojibake in the original message ("v谩lido" -> "válido").
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")

        input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
        print(f"Prompt generado para el modelo: {input_text}")
        # Removed the original `encode("utf-8").decode("utf-8")` round-trip:
        # it is a byte-for-byte no-op on any valid str.
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
        return tokens

    def inference(self, tokens):
        """Run deterministic beam-search generation over *tokens*."""
        # Beam search with sampling disabled. The original also passed
        # temperature/top_k/top_p, but those are ignored (with warnings)
        # when do_sample=False, so they are dropped here — output is
        # unchanged.
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            "early_stopping": True,
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        """Return the substring from the first '{' to the last '}' inclusive.

        If *output* contains no such span, it is returned unchanged.
        """
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No opening/closing brace found — pass the raw text through.
            return output

    def postprocess(self, outputs):
        """Decode the first generated sequence and trim it to its JSON span."""
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        print(f"Texto generado por el modelo: {decoded_output}")
        print(f"JSON limpiado: {cleaned_output}")

        return {"response": cleaned_output}

    def __call__(self, data):
        """Full pipeline: validate → tokenize → generate → clean → wrap."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
|
|
|
|
|
|
|
# Module-level handler instance — presumably looked up by name by the serving
# runtime at import time (Hugging Face Inference Endpoints convention);
# confirm against the deployment configuration.
handler = EndpointHandler(model_name)
|
|