"""Gradio demo that identifies a bird species from an uploaded photo.

Class labels come from the first row of ``birds.csv``. Model weights are
loaded from the ``ulichovick/birdnet`` repository via
``SurfinBird.from_pretrained``.
"""

import csv

import gradio as gr
import torch
from torchvision import io, transforms

from model import SurfinBird

# All class names are expected on the first row of the CSV. Their position is
# used as the class index, so the order must match the model's outputs.
# newline="" is what the csv module documents for files passed to csv.reader.
with open("birds.csv", "r", encoding="utf-8", newline="") as labels_file:
    rows = list(csv.reader(labels_file, delimiter=","))
if not rows:
    raise ValueError("birds.csv is empty; expected one row of class labels")
birds = rows[0]

config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}

# The previous code first built SurfinBird(config=config) and then discarded
# it. Only the pretrained instance below is ever used.
model = SurfinBird.from_pretrained("ulichovick/birdnet")
# Nothing in this app switches back to training mode, so set eval mode once.
model.eval()

# Built once at import time instead of on every request.
_USR_IMG_TRANSFORM = transforms.Compose([
    transforms.Resize(size=(224, 224)),
])

titulo = "Gimme da bird!"
desc = (
    "Carga la imagen de un ave y el modelo intentará identificar la especie "
    f"del ave en la foto (limitado a {config['num_classes']} especies)"
)


def predice(imagen):
    """Return the predicted species label for the image stored at *imagen*.

    Args:
        imagen: Filesystem path of the uploaded image, as delivered by the
            ``gr.Image(type="filepath")`` input. It is ``None`` when the
            user submits without uploading anything.

    Returns:
        The entry of ``config["labels"]`` at the index of the highest model
        output, converted to ``str``.

    Raises:
        gr.Error: If no image was provided. Gradio shows this to the user.
    """
    if imagen is None:
        raise gr.Error("Por favor, carga una imagen.")

    # Force 3-channel RGB. Without a mode, read_image keeps 1 channel for
    # grayscale files and 4 for RGBA PNGs, but the model is configured with
    # num_channels=3.
    target_image = io.read_image(str(imagen), mode=io.ImageReadMode.RGB)
    # uint8 [0, 255] -> float32 [0, 1].
    target_image = target_image.type(torch.float32) / 255.0
    target_image = _USR_IMG_TRANSFORM(target_image)

    with torch.inference_mode():
        # Add a batch dimension of size 1 before the forward pass.
        logits = model(target_image.unsqueeze(dim=0))
    predicted_index = torch.argmax(logits, dim=1).item()
    return str(config["labels"][predicted_index])


demo = gr.Interface(
    predice,
    inputs=gr.Image(label="gimme da bird", type="filepath"),
    outputs=["text"],
    title=titulo,
    description=desc,
)

if __name__ == "__main__":
    demo.launch()