Alan7's picture
Update app.py
403213e verified
raw
history blame
3.13 kB
import torch
from torchvision import transforms, datasets
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
import gradio as gr
# Define el modelo y carga los pesos guardados
model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
model.classifier[1] = torch.nn.Linear(in_features=1280, out_features=101)
#model.load_state_dict(torch.load('./Model_Food_ProyectoIA'), map_location=torch.device('cpu'))
model.load_state_dict(torch.load('./Model_Food_ProyectoIA', map_location=torch.device('cpu')))
model.eval() # Poner el modelo en modo evaluaci贸n
# Mueve el modelo a la GPU si est谩 disponible
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Define las transformaciones
transform_preprocess = transforms.Compose([
transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Cargar el conjunto de datos Food-101 para obtener la lista de clases
#image_path = '../../compartida/vision-project/'
#food101_dataset = datasets.Food101(image_path, split='train')
#classes = food101_dataset.classes
# Funci贸n para cargar el archivo de texto desde una URL
def load_remote_dataset(url):
response = requests.get(url)
response.raise_for_status() # Aseg煤rate de que la solicitud fue exitosa
return response.text
#da error el load_remote_dataset
#from datasets import load_remote_dataset
# URL del archivo de texto en el Space de Hugging Face
file_url = "https://huggingface.co/spaces/Alan7/ProyectoComputerVision/blob/main/clases.txt"
# Carga el archivo de texto desde la URL
file_content = load_remote_dataset(file_url)
# Lee el contenido del archivo y divide por saltos de l铆nea
classes = file_content.strip().split("\n")
# Funci贸n para predecir la clase de una nueva imagen
def predict_image(image):
image = Image.fromarray(image).convert('RGB') # Convertir la imagen cargada a PIL
image = transform_preprocess(image).unsqueeze(0) # Preprocesar y a帽adir dimensi贸n de batch
image = image.to(device) # Mover la imagen a la GPU si est谩 disponible
with torch.no_grad():
output = model(image) # Realizar la predicci贸n
prediction = torch.nn.functional.softmax(output[0], dim=0) # Aplicar softmax para obtener probabilidades
confidences = {classes[i]: float(prediction[i]) for i in range(101)} # Crear diccionario de clases y probabilidades
return confidences # Devolver las probabilidades de cada clase
# Crear la interfaz de Gradio
interface = gr.Interface(
fn=predict_image,
inputs=gr.Image(type="numpy"),
outputs=gr.Label(num_top_classes=3),
title="Food101 Classifier",
description="Sube una imagen de comida y el modelo clasificar谩 la imagen.",
examples=["https://www.cnature.es/receta/receta-de-hamburguesa-con-guacamole/"] # Reemplaza con rutas de ejemplo
)
# Iniciar la interfaz
interface.launch(share=True)