Añadir implementación de un entorno de desarrollo y carga de modelos con evaluación de precisión
237774d
import torch | |
def cargar_etiquetas(): | |
with open("etiquetas.txt", "r") as f: | |
etiquetas = f.read().splitlines()[1:] | |
num_clases = len(etiquetas) | |
codigo = {etiqueta.lower(): i for i, etiqueta in enumerate(etiquetas)} | |
return etiquetas, num_clases, codigo | |
def multiclass_accuracy(predictions, labels): | |
# Obtén las clases predichas (la clase con la mayor probabilidad) | |
_, predicted_classes = torch.max(predictions, 1) | |
# Compara las clases predichas con las etiquetas verdaderas | |
correct_predictions = (predicted_classes == labels).sum().item() | |
# Calcula la precisión | |
accuracy = correct_predictions / labels.size(0) | |
return accuracy | |