from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
model_name = "jla25/squareV3"
# Module-level load; the EndpointHandler below loads its own copy from model_dir.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_auth_token=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, use_auth_token=True)
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("The input must be a dictionary with the key 'inputs' and a valid value.")
        # Custom prompt used to guide the model
        input_text = f"{data['inputs']}"
        print(f"Prompt passed to the model: {input_text}")
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
        return tokens
    def inference(self, tokens):
        # Note: temperature, top_k and top_p have no effect while do_sample is False;
        # generation uses deterministic beam search (num_beams=5).
        generate_kwargs = {
            "max_length": 1024,
            "num_beams": 5,
            "do_sample": False,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.7,
            "repetition_penalty": 2.5
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs
    def clean_output(self, output):
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            return output
    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)
        # Always print the generated text for debugging
        print(f"Text generated by the model: {decoded_output}")
        print(f"Cleaned JSON: {cleaned_output}")
        return {"response": cleaned_output}
    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
# Create a handler instance
handler = EndpointHandler(model_name)
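
# A minimal local smoke test, as a sketch only: the Inference Endpoints runtime
# normally calls handler(data) with the request payload itself. The "inputs"
# string below is a hypothetical example of what a client might send.
if __name__ == "__main__":
    example = {"inputs": "Example request to be converted into the expected JSON format"}
    result = handler(example)
    print(json.dumps(result, indent=2, ensure_ascii=False))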