File size: 1,500 Bytes
a63dc67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torchvision
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt

from DataSet import QuestionDataSet
import TractionModel as plup

import random
from tqdm import tqdm

import gradio as gr


def snap(image):
    """Mirror *image* top-to-bottom by reversing its first axis.

    Equivalent to ``np.flipud``: anything NumPy can turn into an array is
    accepted, and the result is a reversed view rather than a copy.
    """
    return np.flip(image, axis=0)


def init_model(path):
    """Build the TractionModel network, load the weights stored at *path*, and return it.

    ``.eval()`` is called before returning. Assuming the result is a torch
    ``nn.Module``, this puts dropout / batch-norm layers into inference mode.
    """
    net = plup.load_weights(plup.create_model(), path)
    net.eval()
    return net


def inference(image):
    """Classify one uploaded image as "pull-up" vs "no pull-up".

    Args:
        image: PIL image delivered by the Gradio ``Image`` input (``type='pil'``).

    Returns:
        dict mapping ``'pull-up'`` and ``'no pull-up'`` to probabilities that
        sum to 1, in the format expected by the Gradio ``Label`` output.
    """
    # Uploads can be RGBA (PNG with alpha) or single-channel grayscale; the
    # 3-channel Normalize in ``vanilla_transform`` needs exactly RGB.
    image = image.convert("RGB")
    batch = vanilla_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(batch)
    # NOTE(review): assumes pred[1] holds a single pull-up logit for the one
    # image in the batch (shape (1,) or (1, 1)); confirm against TractionModel.
    # Flattening makes the scalar extraction independent of which of those
    # two shapes the model returns.
    res = torch.sigmoid(pred[1].to("cpu")).reshape(-1)[0].item()
    return {'pull-up': res, 'no pull-up': 1 - res}


# Per-channel mean / std used by Normalize below. These are the standard
# ImageNet statistics. NOTE(review): presumably they match the model's
# training-time preprocessing; confirm against the training pipeline.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before it is fed to the model:
#   1. Resize(224): the *shorter* edge is scaled to 224 px, aspect ratio kept.
#   2. ToTensor: PIL image -> float tensor (C, H, W) with values in [0, 1].
#   3. Normalize: per-channel standardisation with the statistics above.
vanilla_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(norm_mean, norm_std),
    ]
)

# Checkpoint served by the demo. NOTE(review): the file name appears to encode
# validation metrics (score / F1 values); confirm before swapping checkpoints.
MODEL_PATH = "output/model/model-score0.96-f1_10.9-f1_20.99.pt"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level globals read by ``inference``.
model = init_model(MODEL_PATH).to(device)

# live=True re-runs ``inference`` whenever the uploaded image changes.
iface = gr.Interface(
    inference,
    live=True,
    inputs=gr.inputs.Image(source="upload", tool=None, type='pil'),
    outputs=gr.outputs.Label(),
)

if __name__ == "__main__":
    # Smoke-test the interface only when run as a script. Previously this
    # ran on every import of the module.
    iface.test_launch()
    iface.launch()