# birdie / app.py
# Author: ulichovick — last change: "Update app.py" (commit d216ada, verified)
import gradio as gr
from model import SurfinBird
from torchvision import transforms, io
import csv
import torch
# Species labels: the first row of birds.csv, one comma-separated entry per
# class. The list index is assumed to match the model's output index
# (NOTE(review): confirm against the training label order).
# newline="" is what the csv module documents for files passed to csv.reader.
with open("birds.csv", "r", newline="", encoding="utf-8") as r:
    birds = next(csv.reader(r, delimiter=","), None)
if birds is None:
    # Fail with a clear message instead of an IndexError on an empty file.
    raise ValueError("birds.csv is empty; expected one row of species labels")

# config["labels"] is read by predice() to turn the predicted index into a name.
config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}

# Load the trained weights directly; constructing a throwaway
# SurfinBird(config=config) beforehand served no purpose since it was
# immediately replaced.
model = SurfinBird.from_pretrained("ulichovick/birdnet")
# UI strings shown by the Gradio interface (the app's UI language is Spanish).
titulo = "Gimme da bird!"
# "Upload a picture of a bird and the model will try to identify the species
# of the bird in the photo (limited to 525 species)."
desc = "Carga la imagen de un ave y el modelo intentará identificar la especie del ave en la foto (limitado a 525 especies)."
def predice(imagen):
    """Classify the bird in an uploaded image and return its species label.

    Args:
        imagen: Path to the uploaded image file, as supplied by the
            ``gr.Image(type="filepath")`` input; ``None`` when the user
            submits without uploading anything.

    Returns:
        The predicted species name, looked up in ``config["labels"]`` by the
        index of the highest model output.

    Raises:
        gr.Error: If no image was provided.
    """
    if imagen is None:
        # Otherwise str(None) == "None" would be handed to read_image and show
        # up in the UI as an opaque file-not-found error.
        raise gr.Error("Carga una imagen de un ave primero.")

    usr_img_transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
    ])

    # Force 3-channel RGB decoding. Without a mode, PNGs with an alpha channel
    # decode to 4 channels and grayscale images to 1, neither of which matches
    # config["num_channels"] == 3 (presumably also the pretrained model's input).
    target_image = io.read_image(str(imagen), mode=io.ImageReadMode.RGB)
    # uint8 values in [0, 255] -> float32 in [0, 1].
    target_image = target_image.type(torch.float32) / 255.
    target_image = usr_img_transform(target_image)

    model.eval()
    with torch.inference_mode():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        target_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(target_image)

    # Index of the highest-scoring class for the single image in the batch.
    pred_idx = torch.argmax(target_image_pred, dim=1).item()
    return str(config["labels"][pred_idx])
# Web UI: one image upload in, the predicted species name (plain text) out.
demo = gr.Interface(
    predice,
    inputs=gr.Image(label="gimme da bird", type="filepath"),
    outputs=["text"],
    title=titulo,
    description=desc,
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start the web server as a side effect.
if __name__ == "__main__":
    demo.launch()