File size: 3,200 Bytes
0d16e19 c5a16f8 0d16e19 e29f84e 05cc433 e29f84e 05cc433 0d16e19 05cc433 a87db35 0d16e19 05da434 c5a16f8 d584a44 2433496 463dbb5 5888de7 0d16e19 05da434 5888de7 66a0eb1 5888de7 66a0eb1 5888de7 05da434 0d16e19 05da434 0d16e19 8708518 d584a44 8708518 d584a44 c5a16f8 d584a44 c5a16f8 d584a44 c5a16f8 0d16e19 d584a44 8708518 d584a44 8708518 c5a16f8 8708518 c5a16f8 d584a44 c5a16f8 0d16e19 a87db35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
# Hugging Face Hub repo id of the seq2seq model served by this handler.
model_name = "jla25/squareV3"
# NOTE(review): these module-level loads run at import time and duplicate the
# loads performed in EndpointHandler.__init__, doubling startup cost/memory.
# They look unused by the class itself — confirm no other module imports
# `tokenizer` / `model` from here before removing them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
class EndpointHandler:
    """Custom Inference Endpoints handler.

    Pipeline: validate/tokenize the request payload, run beam-search
    generation with a seq2seq model, then extract and validate a JSON
    object from the generated text.
    """

    def __init__(self, model_dir):
        """Load the tokenizer and model.

        Args:
            model_dir: local path (or repo id) supplied by the serving
                runtime. Bug fix: the original ignored this parameter and
                reloaded from the global ``model_name``; the Endpoint
                contract is to load from ``model_dir``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        """Validate the request payload and tokenize ``data['inputs']``.

        Raises:
            ValueError: if ``data`` is not a dict containing a non-None
                'inputs' key.
        """
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
        # Bug fix: the original wrote `({data['inputs']})`, a SET literal,
        # which handed a `set` to the tokenizer instead of the text itself.
        input_text = data["inputs"]
        # Log the prompt for debugging.
        print(f"Prompt generado para el modelo: {input_text}")
        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=1024,
        )
        return tokens

    def inference(self, tokens):
        """Run deterministic beam-search generation without gradients."""
        # Note: `temperature`, `top_k`, `top_p` were removed — they only
        # apply when do_sample=True and recent transformers versions warn
        # when they are passed with greedy/beam search.
        generate_kwargs = {
            "max_length": 1024,
            "num_beams": 5,
            "do_sample": False,
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        """Return the span from the first '{' to the last '}' (inclusive).

        Falls back to returning ``output`` unchanged when no braces are
        present (str.index/rindex raise ValueError).
        """
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No JSON-looking span in the generated text.
            return output

    def validate_json(self, json_text):
        """Parse ``json_text`` and verify it holds a list under 'values'.

        Returns:
            dict: ``{"is_valid": True, "json_data": ...}`` on success,
            otherwise ``{"is_valid": False, "error": ...}``.
        """
        try:
            json_data = json.loads(json_text)
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
        if "values" in json_data and isinstance(json_data["values"], list):
            return {"is_valid": True, "json_data": json_data}
        return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}

    def postprocess(self, outputs):
        """Decode the first generated sequence and validate its JSON.

        Raises:
            ValueError: when the generation contains no valid JSON with
                the expected 'values' list.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)
        # Always log raw and cleaned generations for debugging.
        print(f"Texto generado: {decoded_output}")
        print(f"JSON limpiado: {cleaned_output}")
        validation_result = self.validate_json(cleaned_output)
        if not validation_result["is_valid"]:
            print(f"Error en la validación: {validation_result['error']}")
            raise ValueError(f"JSON inválido: {validation_result['error']}")
        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        """Full request pipeline: preprocess -> inference -> postprocess."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
|