File size: 3,698 Bytes
b1c6042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96f646b
 
 
 
 
b1c6042
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from model.nets import my_model
import torch
import cv2
import torch.utils.data as data
import torchvision.transforms as transforms
import PIL
from PIL import Image
from PIL import ImageFile
import math
import os
import torch.nn.functional as F

# Expose only physical GPU 1 to this process; the single visible device is
# then addressed as "cuda:0". When CUDA is unavailable the demo runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Both checkpoints are loaded into networks built with identical settings.
_MODEL_KWARGS = dict(
    en_feature_num=48,
    en_inter_num=32,
    de_feature_num=64,
    de_inter_num=32,
    sam_number=1,
)

# Model 1: weights read from ./mix.pth.
model1 = my_model(**_MODEL_KWARGS).to(device)
load_path1 = "./mix.pth"
model_state_dict1 = torch.load(load_path1, map_location=device)
model1.load_state_dict(model_state_dict1)
# The demo only runs inference, so put the network in evaluation mode; this
# matters if my_model contains dropout / batch-norm layers and is a no-op
# otherwise.
model1.eval()

# Model 2: weights read from ./uhdm_checkpoint.pth.
model2 = my_model(**_MODEL_KWARGS).to(device)
load_path2 = "./uhdm_checkpoint.pth"
model_state_dict2 = torch.load(load_path2, map_location=device)
model2.load_state_dict(model_state_dict2)
model2.eval()

def default_toTensor(img):
    """Convert ``img`` to a tensor with a one-step torchvision pipeline.

    The pipeline is a ``transforms.Compose`` holding a single
    ``transforms.ToTensor()``; its result is returned unchanged.
    """
    pipeline = transforms.Compose([transforms.ToTensor()])
    tensor = pipeline(img)
    return tensor

def _restore_image(model, img):
    """Run ``model`` on ``img`` and return its first output as a PIL image.

    The input is converted to a 1 x C x H x W tensor, padded so that H and W
    are exact multiples of 32, passed through ``model`` without gradient
    tracking, and the first of the three returned tensors is cropped back to
    the original H x W.

    Args:
        model: callable returning three tensors; the first is assumed to have
            the same spatial size as the padded input.
        img: image accepted by ``transforms.ToTensor`` (Gradio passes a PIL
            image here).

    Returns:
        ``PIL.Image`` built from the first output, scaled by 255 and clamped
        to the 0..255 byte range.
    """
    in_img = transforms.ToTensor()(img).to(device).unsqueeze(0)
    _, _, h, w = in_img.size()

    # Total amount missing to reach the next multiple of 32 (0 when aligned).
    pad_h = -h % 32
    pad_w = -w % 32
    # Split the deficit into two halves. When it is odd the halves differ by
    # one; padding both sides by the floored half (the previous behaviour)
    # left the result one pixel short of a multiple of 32.
    top, left = pad_h // 2, pad_w // 2
    bottom, right = pad_h - top, pad_w - left

    # img_pad pads symmetrically, so pad with the larger half on every side
    # and trim the surplus leading row / column (at most one of each).
    in_img = img_pad(in_img, w_r=right, h_r=bottom)
    in_img = in_img[:, :, bottom - top:, right - left:].contiguous()

    with torch.no_grad():
        out_1, _, _ = model(in_img)

    # Explicit start/stop indices crop correctly even when no padding was
    # added, so no special-casing of zero padding is needed.
    out_1 = out_1[:, :, top:top + h, left:left + w].squeeze(0)
    out_1 = PIL.Image.fromarray(
        torch.clamp(out_1 * 255, min=0, max=255)
        .byte()
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    return out_1


def predict1(img):
    """Process ``img`` with ``model1`` and return the result as a PIL image."""
    return _restore_image(model1, img)


def predict2(img):
    """Process ``img`` with ``model2`` and return the result as a PIL image."""
    return _restore_image(model2, img)

def img_pad(x, h_r=0, w_r=0):
    """Pad the first three channels of ``x`` with per-channel constants.

    Channels 0, 1 and 2 of ``x`` are each padded by ``w_r`` columns on the
    left and right and ``h_r`` rows on the top and bottom, then concatenated
    back along the channel dimension. According to the original note, the
    fill values are the average r, g, b values across the FHDMi training
    set, and means computed from the UHDM training set yield similar
    performance.

    Args:
        x: tensor indexed as ``x[:, channel, ...]``.
        h_r: rows added to both the top and the bottom.
        w_r: columns added to both the left and the right.

    Returns:
        The padded tensor.
    """
    channel_fill = (0.3827, 0.4141, 0.3912)
    border = (w_r, w_r, h_r, h_r)

    padded_channels = [
        F.pad(x[:, idx:idx + 1, ...], border, value=fill)
        for idx, fill in enumerate(channel_fill)
    ]

    return torch.cat(padded_channels, dim=1)


# One interface per model. Each takes a PIL image and returns a PIL image.
# The output slot uses an output component (gr.outputs.Image); the original
# passed an input component (gr.inputs.Image) there.
iface1 = gr.Interface(
    fn=predict1,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Image(type="pil"),
)

iface2 = gr.Interface(
    fn=predict2,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Image(type="pil"),
)

# Combine both interfaces so they are run on the same input in one UI.
iface_all = gr.mix.Parallel(
    iface1,
    iface2,
    examples=[
        "001.jpg",
        "002.jpg",
        "003.jpg",
        "004.jpg",
        "005.jpg",
    ],
)

iface_all.launch()