File size: 992 Bytes
d6c6696
 
 
 
 
 
 
 
 
 
6732afa
 
 
d6c6696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import functools
import os

import gradio as gr
import PIL
import PIL.Image
import torch
from dotenv import load_dotenv

from model_inference import ModelInference

# Artifact file names handed to ModelInference. They are relative paths;
# presumably resolved against the current working directory — confirm in
# model_inference if the app is launched from elsewhere.
GLOBAL_MODEL: str = "global_model_config23_21ep_timm_resnet50.pt"  # model checkpoint
CATEGORY_MAP: str = "category_map.json"  # category mapping file (schema defined by ModelInference)
CATEG_TO_NAME_MAP: str = "categ_to_name_map.json"  # category -> display-name mapping file

# Populate the process environment from a .env file, if one exists; a missing
# file is not an error.
load_dotenv()


# Model prediction function
def predict_species(image: PIL.Image.Image) -> dict[str, float]:
    """Moth species prediction"""

    # Build the model class
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fgrained_classifier = ModelInference(
        GLOBAL_MODEL, "timm_resnet50", CATEGORY_MAP, CATEG_TO_NAME_MAP, device, topk=5
    )

    # Predict on image
    sp_pred = fgrained_classifier.predict(image)

    return sp_pred


# Web UI: one PIL image in, a label/score panel out, wired to predict_species.
_image_input = gr.Image(type="pil")
_label_output = gr.Label()
demo = gr.Interface(
    title="Mila Global Moth Species Classifier",
    outputs=_label_output,
    inputs=_image_input,
    fn=predict_species,
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    demo.launch(share=True)