File size: 1,654 Bytes
1f53a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import gradio as gr
from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin


# Checkpoint file whose weights are loaded into the model at startup
# (relative path, so presumably resolved against the working directory).
test_weight: str = './weight_epoch-200_best.pt'
# Saved parameter file passed to load_parameter() to rebuild the arguments.
parameter: str = './parameters.json'

class ImageHandler(ImageMixin):
    """Turn one uploaded image into the batched input dict the model consumes.

    The transform pipeline is obtained from ``ImageMixin._make_transforms`` once
    per instance. It presumably reads ``self.params`` (TODO confirm in
    ``ImageMixin``), which is why ``params`` is stored before it is called.
    """

    def __init__(self, params):
        # Keep this order: _make_transforms may depend on self.params.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Transform *image* and wrap it as a single-item batch.

        Returns ``{'image': batch}``, where ``batch`` is the transformed image
        with a new leading dimension of size 1 added by ``unsqueeze(0)``.
        """
        transformed = self.transform(image)
        return {'image': transformed.unsqueeze(0)}

def load_parameter(parameter):
    """Rebuild the saved arguments and split them for CPU inference.

    Args:
        parameter: Location of the saved parameter file; it is handed to
            ``_retrieve_parameter`` to obtain a name -> value mapping.

    Returns:
        A ``(args_model, args_dataloader)`` tuple, produced by
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'`` groups.
    """
    args = ParamSet()
    for key, value in _retrieve_parameter(parameter).items():
        setattr(args, key, value)

    # Inference-time settings applied on top of whatever was saved:
    # augmentation and sampler set to 'no', pretrained disabled, mlp cleared,
    # net mirroring the saved `model` value, and execution pinned to CPU.
    inference_overrides = {
        'augmentation': 'no',
        'sampler': 'no',
        'pretrained': False,
        'mlp': None,
        'net': args.model,
        'device': torch.device('cpu'),
    }
    for key, value in inference_overrides.items():
        setattr(args, key, value)

    return (
        _dispatch_by_group(args, 'model'),
        _dispatch_by_group(args, 'dataloader'),
    )

# Done once at import time: read the saved parameters, split them into the
# model and dataloader argument groups, build the network, and load the
# weights stored at `test_weight`. `main` reuses the resulting module-level
# `model` and `args_dataloader` on every request.
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)

def main(image):
    """Run the loaded model on one uploaded image and return its score as text.

    Args:
        image: Value from the ``gr.Image(type='pil', image_mode='L')`` input,
            i.e. a grayscale PIL image, or ``None`` when the user submits
            without uploading anything.

    Returns:
        The prediction stored under the first key of the model's output
        mapping, formatted with two decimal places (e.g. ``'0.37'``).

    Raises:
        gr.Error: If no image was provided or the model returned no outputs,
            so Gradio shows a readable message instead of a stack trace.
    """
    if image is None:
        # Otherwise the transform would be called on None and fail obscurely.
        raise gr.Error('Please upload an image.')

    model.eval()
    image_handler = ImageHandler(args_dataloader)
    batch = image_handler.set_image(image)

    with torch.no_grad():
        outputs = model(batch)

    if not outputs:
        raise gr.Error('The model returned no outputs.')

    # Only the first output label is reported.
    label_name = next(iter(outputs))
    # Under no_grad there is no graph to detach from; .item() converts the
    # single-element tensor directly to a Python number.
    result = outputs[label_name].item()
    return f"{result:.2f}"


# Gradio UI: one grayscale image upload in, the formatted score from main() out.
iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type='pil', image_mode='L')],
    outputs='text',
)
iface.launch()