File size: 1,555 Bytes
eab14d0
7726ad3
 
 
 
eab14d0
1269270
7726ad3
 
 
 
 
 
 
 
 
1269270
7726ad3
1269270
7726ad3
 
1269270
7726ad3
1269270
7726ad3
 
 
 
 
 
 
1269270
7726ad3
1269270
 
 
 
 
 
 
 
 
 
 
62caad2
bf2c5dc
 
 
 
8a11eb6
1269270
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms

def predict(model, image, device):
    """Classify *image* as the Tenka Ippin logo or a no-entry sign.

    Args:
        model: Torch module whose output is passed through a softmax over
            dim 1; column 1 is read as the "tenka ippin" probability.
            (Presumably a 2-class logit head of shape (1, 2) — TODO confirm
            against the training code.)
        image: PIL image (the Gradio input uses ``type="pil"``).
        device: Device the model and the input tensor are moved to before
            inference (e.g. ``"cpu"``).

    Returns:
        dict mapping ``"tenka ippin"`` to the softmax probability of class 1
        and ``"no entry"`` to its complement, both as Python floats.
    """
    # Standard ImageNet-style preprocessing: resize, center crop to 224x224,
    # convert to a tensor and normalize with the ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Normalize expects exactly 3 channels; PNG uploads/examples may be
    # RGBA or palette images and photos may be grayscale, so coerce to RGB.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Add a batch dimension and place the input on the requested device.
    img_tensor = transform(image).unsqueeze(0).to(device)

    model = model.to(device)
    # Inference mode for dropout / batch-norm layers, in case the model was
    # pickled while still in training mode.
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(img_tensor), dim=1)

    # Column 1 is the "tenka ippin" class; convert to a plain float so the
    # result serializes cleanly for the Gradio Label output.
    p_tenka = probs[0, 1].item()
    return {"tenka ippin": p_tenka, "no entry": 1 - p_tenka}

@functools.lru_cache(maxsize=1)
def _load_model(path='models/tenichi_noentry.pth'):
    """Load the pickled classifier at *path* once and reuse it across requests.

    The checkpoint is mapped onto the CPU so that a model saved on a GPU host
    still loads on a CPU-only machine, and it is switched to evaluation mode.
    """
    # NOTE(review): this unpickles a full nn.Module from a local, trusted
    # file. torch>=2.6 defaults torch.load(weights_only=True), which rejects
    # full modules; if loading fails there, pass weights_only=False — confirm
    # the torch version pinned for this app.
    model = torch.load(path, map_location='cpu')
    model.eval()
    return model


def process_image(input_image):
    """Gradio callback: classify the uploaded image.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without an image.

    Returns:
        The label -> probability dict produced by ``predict``, or ``None``
        (leaving the Label output empty) when no image was provided.
    """
    if input_image is None:
        return None
    preds = predict(_load_model(), input_image, 'cpu')
    return preds


# Example images shown under the input; each row supplies the single input.
_EXAMPLES = [
    ["examples/ten20.png"],
    ["examples/noe33.png"],
]

# Web UI: one PIL image in, a label/probability display out.
# The description reads: "Determines whether the image shows the Tenka Ippin
# logo or a no-entry sign."
iface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil", label="Input Image", height=512)],
    outputs=gr.Label(label="Output", show_label=False),
    description="画像に映っているのが天下一品のロゴなのか、進入禁止標識なのか判別します",
    examples=_EXAMPLES,
)


def _main():
    """Start the Gradio server hosting the demo."""
    iface.launch()


if __name__ == "__main__":
    _main()