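"""Gradio demo for MetaFG image classification.

Loads a MetaFG checkpoint, config file, and class-name list (downloading each
from the Hugging Face Hub when it is not supplied on the command line) and
serves a simple image-classification interface.
"""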
import argparse
import glob
import os

import gradio as gr
from huggingface_hub import hf_hub_download

from inference import Inference

def parse_option():
    parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file')
    # easy config modification
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--img-size', type=int, default=384, help='input image size in pixels')
    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data file')
    # default is None so the Hugging Face Hub fallback below is reachable
    parser.add_argument('--names-path', default=None, type=str, help='path to class names file')
    args = parser.parse_args()
    return args
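
# Example invocation (the script name app.py is an assumption; the file names
# match the Hub fallbacks used below):
#   python app.py --cfg MetaFG_2_224.yaml --model-path ckpt_epoch_35.pth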

if __name__ == '__main__':
    args = parse_option()

    # Fall back to downloading the checkpoint, config, and class names from the
    # Hugging Face Hub when they are not given on the command line.
    # HUGGINGFACE_TOKEN is optional: hf_hub_download accepts token=None for
    # public repositories.
    token = os.environ.get("HUGGINGFACE_TOKEN")

    if not args.model_path:
        model_path = hf_hub_download(repo_id="joshvm/inaturalist_sgd_4k",
                                     filename="ckpt_epoch_35.pth",
                                     token=token)
    else:
        model_path = args.model_path

    if not args.cfg:
        model_config = hf_hub_download(repo_id="joshvm/inaturalist_sgd_4k",
                                       filename="MetaFG_2_224.yaml",
                                       token=token)
    else:
        model_config = args.cfg

    if not args.names_path:
        names_path = hf_hub_download(repo_id="joshvm/inaturalist_sgd_4k",
                                     filename="names_mf2.txt",
                                     token=token)
    else:
        names_path = args.names_path

    model = Inference(config_path=model_config,
                      model_path=model_path,
                      names_path=names_path)
    
    def classify(image):
        # Run inference on the uploaded image; the metadata path respects the
        # --meta-path argument instead of being hard-coded.
        preds = model.infer(img_path=image, meta_data_path=args.meta_path)
        # confidences = {c: float(preds[i]) for i, c in enumerate(model.classes)}

        return preds
    
    gr.Interface(fn=classify,
                 inputs=gr.Image(shape=(args.img_size, args.img_size), type="pil"),
                 outputs=gr.Label(num_top_classes=10),
                 examples=glob.glob("./example_images/*")).launch()