Spaces:
Runtime error
Runtime error
import torch | |
from torchvision import transforms, datasets | |
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights | |
from torchvision.transforms.functional import InterpolationMode | |
from PIL import Image | |
import gradio as gr | |
# Define el modelo y carga los pesos guardados | |
model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT) | |
model.classifier[1] = torch.nn.Linear(in_features=1280, out_features=101) | |
#model.load_state_dict(torch.load('./Model_Food_ProyectoIA'), map_location=torch.device('cpu')) | |
model.load_state_dict(torch.load('./Model_Food_ProyectoIA', map_location=torch.device('cpu'))) | |
model.eval() # Poner el modelo en modo evaluaci贸n | |
# Mueve el modelo a la GPU si est谩 disponible | |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | |
model.to(device) | |
# Define las transformaciones | |
transform_preprocess = transforms.Compose([ | |
transforms.Resize(256, interpolation=InterpolationMode.BICUBIC), | |
transforms.CenterCrop(224), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
# Cargar el conjunto de datos Food-101 para obtener la lista de clases | |
#image_path = '../../compartida/vision-project/' | |
#food101_dataset = datasets.Food101(image_path, split='train') | |
#classes = food101_dataset.classes | |
# Funci贸n para cargar el archivo de texto desde una URL | |
def load_remote_dataset(url): | |
response = requests.get(url) | |
response.raise_for_status() # Aseg煤rate de que la solicitud fue exitosa | |
return response.text | |
#da error el load_remote_dataset | |
#from datasets import load_remote_dataset | |
# URL del archivo de texto en el Space de Hugging Face | |
file_url = "https://huggingface.co/spaces/Alan7/ProyectoComputerVision/blob/main/clases.txt" | |
# Carga el archivo de texto desde la URL | |
file_content = load_remote_dataset(file_url) | |
# Lee el contenido del archivo y divide por saltos de l铆nea | |
classes = file_content.strip().split("\n") | |
# Funci贸n para predecir la clase de una nueva imagen | |
def predict_image(image): | |
image = Image.fromarray(image).convert('RGB') # Convertir la imagen cargada a PIL | |
image = transform_preprocess(image).unsqueeze(0) # Preprocesar y a帽adir dimensi贸n de batch | |
image = image.to(device) # Mover la imagen a la GPU si est谩 disponible | |
with torch.no_grad(): | |
output = model(image) # Realizar la predicci贸n | |
prediction = torch.nn.functional.softmax(output[0], dim=0) # Aplicar softmax para obtener probabilidades | |
confidences = {classes[i]: float(prediction[i]) for i in range(101)} # Crear diccionario de clases y probabilidades | |
return confidences # Devolver las probabilidades de cada clase | |
# Crear la interfaz de Gradio | |
interface = gr.Interface( | |
fn=predict_image, | |
inputs=gr.Image(type="numpy"), | |
outputs=gr.Label(num_top_classes=3), | |
title="Food101 Classifier", | |
description="Sube una imagen de comida y el modelo clasificar谩 la imagen.", | |
examples=["https://www.cnature.es/receta/receta-de-hamburguesa-con-guacamole/"] # Reemplaza con rutas de ejemplo | |
) | |
# Iniciar la interfaz | |
interface.launch(share=True) |