# Mila_Global_Moth_Classifier / gradio_demo.py
# Uploaded by adityajain07 using huggingface_hub (commit d6c6696, verified).
import functools
import os

import gradio as gr
import PIL
import PIL.Image
import torch
from dotenv import load_dotenv

from model_inference import ModelInference
# Pull optional secrets/config from a local .env file (if one exists) into
# the process environment, so the lookups below can see them.
load_dotenv()

# Settings forwarded to ModelInference by predict_species; each one is None
# when the corresponding environment variable is not set.
GLOBAL_MODEL = os.environ.get("GLOBAL_MODEL")
CATEGORY_MAP = os.environ.get("CATEGORY_MAP_JSON")
CATEG_TO_NAME_MAP = os.environ.get("CATEG_TO_NAME_MAP")
# Model prediction function
def predict_species(image: PIL.Image.Image) -> dict[str, float]:
"""Moth species prediction"""
# Build the model class
device = "cuda" if torch.cuda.is_available() else "cpu"
fgrained_classifier = ModelInference(
GLOBAL_MODEL, "timm_resnet50", CATEGORY_MAP, CATEG_TO_NAME_MAP, device, topk=5
)
# Predict on image
sp_pred = fgrained_classifier.predict(image)
return sp_pred
# UI components: one image (delivered to the handler as a PIL image) in,
# a label/confidence display out.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Label()

demo = gr.Interface(
    fn=predict_species,
    inputs=_image_input,
    outputs=_prediction_output,
    title="Mila Global Moth Species Classifier",
)

if __name__ == "__main__":
    # share=True also asks Gradio to create a public share link.
    demo.launch(share=True)