import torch
from safetensors.torch import load_model
from models import FromZero, PreTrained
from utils import multiclass_accuracy


def cargar_evaluar_modelo(archivo, tipo_modelo, num_clases, test_dataloader):
    try:
        if tipo_modelo == "tarea_7":
            modelo = PreTrained(num_clases)
        elif tipo_modelo == "tarea_8":
            modelo = FromZero(num_clases)

        load_model(modelo, archivo)
        modelo.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        modelo.to(device)
        accuracy = 0

        with torch.no_grad():
            for imagenes, etiquetas in test_dataloader:
                imagenes = imagenes.to(device)
                etiquetas = etiquetas.to(device)
                predictions = modelo(imagenes)
                accuracy += multiclass_accuracy(predictions, etiquetas)

        accuracy = accuracy / len(test_dataloader)
        return accuracy
    except Exception as e:
        return f"Error: {str(e)}"


def evaluate_interface(model_file, model_type, num_clases, test_dataloader):
    if model_file is None:
        return "Por favor, carga un archivo .safetensor"

    # Verificamos que el archivo sea .safetensor
    if not model_file.name.endswith(".safetensor") or model_file.name.endswith(
        ".safetensors"
    ):
        return "Por favor, carga un archivo con extensión .safetensor o .safetensors"

    # Evaluamos el modelo
    accuracy = cargar_evaluar_modelo(
        model_file.name, model_type, num_clases, test_dataloader
    )

    if isinstance(accuracy, float):
        return f"Precisión del modelo: {accuracy*100:.2f}%"
    else:
        return accuracy