|
import os |
|
|
|
import gradio as gr |
|
import PIL |
|
import torch |
|
from dotenv import load_dotenv |
|
from model_inference import ModelInference |
|
|
|
|
|
# File names of the artifacts handed to ModelInference in predict_species.
# Presumably the trained ResNet-50 checkpoint and the two JSON label maps;
# how ModelInference resolves them (cwd-relative path vs. other lookup) is
# not visible here — TODO confirm before moving the files.
GLOBAL_MODEL = "global_model_config23_21ep_timm_resnet50.pt"
CATEGORY_MAP = "category_map.json"
CATEG_TO_NAME_MAP = "categ_to_name_map.json"

# Load variables from a .env file (if present) into os.environ. The constants
# above do not read the environment, so their order relative to this call
# does not matter.
load_dotenv()
|
|
|
|
|
|
|
# Classifiers already built in this process, keyed by device string ("cuda" or
# "cpu"). Building a ModelInference loads the checkpoint and label maps, so it
# is done once per device instead of once per request.
# NOTE(review): the cached instance is shared by every Gradio request; confirm
# ModelInference.predict keeps no per-call mutable state if concurrency is
# raised above Gradio's default.
_CLASSIFIERS: dict[str, ModelInference] = {}


def _get_classifier(device: str) -> ModelInference:
    """Return the species classifier for *device*, building it on first use."""
    classifier = _CLASSIFIERS.get(device)
    if classifier is None:
        classifier = ModelInference(
            GLOBAL_MODEL,
            "timm_resnet50",
            CATEGORY_MAP,
            CATEG_TO_NAME_MAP,
            device,
            topk=5,
        )
        _CLASSIFIERS[device] = classifier
    return classifier


def predict_species(image: PIL.Image.Image) -> dict[str, float]:
    """Moth species prediction.

    Args:
        image: The uploaded picture as a PIL image. Gradio passes ``None``
            when the user submits without uploading anything.

    Returns:
        The result of ``ModelInference.predict`` for a classifier built with
        ``topk=5`` — assumed to be a species-name -> confidence mapping suitable
        for ``gr.Label`` (TODO confirm against ModelInference).

    Raises:
        gr.Error: If no image was supplied. Gradio displays the message to the
            user instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Prefer the GPU when one is available; each device gets its own cached model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return _get_classifier(device).predict(image)
|
|
|
|
|
# Web UI: one PIL image in, a ranked label/confidence widget out.
demo = gr.Interface(
    title="Mila Global Moth Species Classifier",
    fn=predict_species,
    inputs=gr.Image(type="pil"),  # hand predict_species a PIL.Image, not a numpy array
    outputs=gr.Label(),
)

if __name__ == "__main__":
    # share=True makes Gradio open a temporary public URL in addition to the
    # local server — anyone with the link can reach the app.
    demo.launch(share=True)
|
|