Spaces:
Sleeping
Sleeping
import gradio as gr | |
from model import SurfinBird | |
from torchvision import transforms, io | |
import csv | |
import torch | |
# Class labels: the first row of birds.csv is read as the ordered list of
# species names. predice indexes this list with the model's argmax, so the
# order is assumed to match the model's output classes (TODO confirm against
# the training label order).
# newline="" is what the csv module requires for file objects; utf-8 is made
# explicit because species names may contain non-ASCII characters.
with open("birds.csv", "r", encoding="utf-8", newline="") as labels_file:
    birds = list(csv.reader(labels_file, delimiter=","))
birds = birds[0]

# Model hyper-parameters. In the visible code only config["labels"] is read
# (by predice) to map a predicted class index back to a species name.
config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}

# Load the trained model (presumably from the Hugging Face Hub repo
# "ulichovick/birdnet"). No separate SurfinBird(config=config) construction is
# needed first: it would only build a throwaway randomly-initialised model.
model = SurfinBird.from_pretrained("ulichovick/birdnet")

# User-facing UI text (Spanish).
titulo = "Gimme da bird!"
desc = "Carga la imagen de un ave y el modelo intentará identificar la especie del ave en la foto (limitado a 525 especies) "
# Preprocessing for every uploaded image: resize to 224x224. Built once at
# import time instead of being rebuilt on every request.
_USR_IMG_TRANSFORM = transforms.Compose([
    transforms.Resize(size=(224, 224)),
])


def predice(imagen):
    """Predict the bird species shown in an uploaded image.

    Args:
        imagen: Path to the image file, as supplied by the ``gr.Image``
            component configured with ``type="filepath"``; ``None`` when the
            user submits without uploading an image.

    Returns:
        The predicted species name (an entry of ``config["labels"]``) as a
        string, or a prompt asking for an image when ``imagen`` is ``None``.
    """
    if imagen is None:
        # Nothing was uploaded; avoid calling read_image("None").
        return "Por favor, carga una imagen."
    # Force 3-channel RGB decoding. Without it, grayscale or RGBA files (e.g.
    # PNGs with alpha) decode to 1 or 4 channels, while the local config
    # declares num_channels=3.
    target_image = io.read_image(str(imagen), mode=io.ImageReadMode.RGB)
    # read_image yields uint8 in [0, 255]; scale to float32 in [0, 1].
    target_image = target_image.type(torch.float32) / 255.
    target_image = _USR_IMG_TRANSFORM(target_image)
    model.eval()
    with torch.inference_mode():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        target_image_pred = model(target_image.unsqueeze(dim=0))
    # Index of the highest-scoring class along dim 1.
    target_image_pred_label = torch.argmax(target_image_pred, dim=1)
    label = config["labels"][target_image_pred_label.item()]
    return str(label)
# Assemble the Gradio UI and start serving it. The image component hands
# predice a file path; the prediction is shown in a text output.
demo = gr.Interface(
    fn=predice,
    inputs=gr.Image(label="gimme da bird", type="filepath"),
    outputs=["text"],
    title=titulo,
    description=desc,
)
demo.launch()