# Author: Erick Garcia Espinosa
# Commit 6b394c3: "Add application file and dependencies" (file size 1.88 kB)
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from timm import create_model
# Mapping from class name to the index of the corresponding model output.
# Built from an ordered label list, so insertion order equals index order 0..4.
# (Spanish labels: Sarampion = measles, Varicela = chickenpox.)
class_to_idx = {
    label: index
    for index, label in enumerate(
        ['Monkeypox', 'Melanoma', 'Herpes', 'Sarampion', 'Varicela']
    )
}
# Image preprocessing pipeline: resize to 224x224 (the input size named by
# 'vit_base_patch16_224') and convert the PIL image to a tensor with ToTensor.
# No normalization is applied; presumably this mirrors the pipeline used when
# the checkpoint was trained — confirm before changing it.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Helper that loads an image file and preprocesses it for the model.
def load_image(image_path):
    """Load an image from disk and return it as a model-ready batch tensor.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The RGB image after the module-level ``transform`` (resize to 224x224,
        ToTensor), with a leading batch dimension of size 1 added.
    """
    # Image.open is lazy and keeps the file handle open until the data is
    # loaded (and, for some formats, beyond). The context manager closes it
    # deterministically. convert() runs inside the block, so pixel data is
    # fully read before the file is closed.
    with Image.open(image_path) as opened:
        rgb_image = opened.convert('RGB')  # force 3 channels (drops alpha/palette)
    return transform(rgb_image).unsqueeze(0)  # add the batch dimension
# Build the ViT classifier and load the fine-tuned checkpoint.
model_name = 'vit_base_patch16_224'
# load_state_dict below runs with strict=True (the default), so every parameter
# is overwritten by the checkpoint. Fetching ImageNet weights first
# (pretrained=True) would be discarded immediately and only adds a network
# download at startup.
pretrained = False
num_classes = len(class_to_idx)  # one output per class label
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
model.load_state_dict(
    torch.load(
        'ARTmodelo5ns_vit_weights_epoch6.pth',
        map_location='cpu',  # load onto CPU so GPU-less hosts work
        weights_only=True,   # restrict unpickling to tensors/primitive containers
    )
)
model.eval()  # evaluation mode (e.g. dropout disabled) for inference
# Prediction function wired into the Gradio interface.
def predict_image(img):
    """Classify an image and return the predicted class name.

    Args:
        img: The image to classify. May be a file path (``str``, which is what
            ``gr.Image(type="filepath")`` delivers), a ``numpy.ndarray``, or a
            ``PIL.Image.Image``. ``None`` (no image submitted) is accepted.

    Returns:
        The key of ``class_to_idx`` whose index has the highest model score,
        or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing was uploaded; avoid crashing inside the transform.
        return None

    if isinstance(img, str):
        # A file path (the Gradio input uses type="filepath"). Passing the
        # string directly to `transform` fails, so load it from disk first.
        img_tensor = load_image(img)
    else:
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        # Force 3 channels: RGBA/greyscale/palette images would otherwise give
        # a tensor with the wrong channel count for the model.
        img_tensor = transform(img.convert('RGB')).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        output = model(img_tensor)
    _, predicted = torch.max(output, 1)

    # Invert the mapping explicitly rather than relying on the dict's
    # insertion order matching its index values.
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    return idx_to_class[predicted.item()]
# Build the Gradio interface: one image input, one label output.
# The user-facing Spanish strings below had mis-encoded accents (e.g.
# "Predicci贸n"); they are restored to the intended characters.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Sube una imagen"),
    outputs=gr.Label(label="Predicción"),
    title="Clasificación de Imágenes de Lesiones Cutáneas",
    description="Carga una imagen de una lesión cutánea para obtener una predicción.",
)
# Start the Gradio app.
iface.launch()