Spaces:
Sleeping
Sleeping
File size: 1,276 Bytes
421ae0c 8776fa5 6aa742a a3214e3 d23ca88 6aa742a dd43163 6aa742a 3f62def d216ada 6aa742a 728771a 7241802 dd43163 7241802 728771a dd43163 6aa742a d4e1cba f07397f dd43163 421ae0c 6aa742a 03e7f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import gradio as gr
from model import SurfinBird
from torchvision import transforms, io
import csv
import torch
# Species labels: the first row of birds.csv holds the class names; ``predice``
# maps the model's argmax output index into this list via config["labels"].
# newline="" is what the csv module expects for files passed to csv.reader.
with open("birds.csv", "r", newline="", encoding="utf-8") as r:
    birds = list(csv.reader(r, delimiter=","))
if not birds:
    raise ValueError("birds.csv is empty; expected a first row of species labels")
birds = birds[0]
config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}
# Load the trained model from the Hugging Face Hub repo. Building a local
# SurfinBird(config=config) first is unnecessary: from_pretrained is called on
# the class and its result is what ``model`` ends up bound to.
# NOTE(review): presumably the Hub checkpoint carries its own architecture
# config; confirm it agrees with num_classes=525 / len(birds).
model = SurfinBird.from_pretrained("ulichovick/birdnet")
# Page title and description shown in the Gradio UI (Spanish).
# desc: "Upload an image of a bird and the model will try to identify the
# species of the bird in the photo (limited to 525 species)".
titulo = "Gimme da bird!"
desc = "Carga la imagen de un ave y el modelo intentará identificar la especie del ave en la foto (limitado a 525 especies)"
def predice(imagen):
    """Classify the bird in an uploaded image and return its species label.

    Args:
        imagen: Path to the uploaded image file (the Gradio input is declared
            with ``type="filepath"``), or ``None`` when the user submits the
            form without uploading an image.

    Returns:
        The entry of ``config["labels"]`` at the index of the highest-scoring
        model output, as a string.

    Raises:
        gr.Error: If no image was provided.
    """
    if imagen is None:
        # Gradio passes None for an empty upload; show a readable message
        # instead of failing inside read_image on the path "None".
        raise gr.Error("Carga una imagen de un ave primero.")

    usr_img_transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
    ])
    # Force 3-channel RGB decoding: grayscale (1 channel) or RGBA PNG
    # (4 channels) uploads would otherwise not match the model's 3-channel input.
    target_image = io.read_image(str(imagen), mode=io.ImageReadMode.RGB)
    # read_image yields uint8 pixels; scale them to float values in [0, 1].
    target_image = target_image.type(torch.float32) / 255.
    target_image = usr_img_transform(target_image)

    model.eval()
    with torch.inference_mode():
        # Add a leading batch dimension of size 1 before the forward pass.
        target_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(target_image)
    target_image_pred_label = torch.argmax(target_image_pred, dim=1)
    label = config["labels"][target_image_pred_label.item()]
    return str(label)
# Build and start the web UI: a single image upload (handed to ``predice`` as a
# file path) and a plain-text output that shows the predicted species label.
demo = gr.Interface(
    fn=predice,
    inputs=gr.Image(label="gimme da bird", type="filepath"),
    outputs=["text"],
    title=titulo,
    description=desc,
)
demo.launch()