File size: 3,383 Bytes
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
 
532251c
ea9f27f
 
532251c
ea9f27f
 
 
 
 
 
 
f4ce0ac
 
eecbd6f
f4ce0ac
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532251c
 
ea9f27f
532251c
ea9f27f
 
 
 
 
 
 
 
 
 
 
532251c
 
ea9f27f
532251c
ea9f27f
 
 
 
 
 
 
 
 
532251c
ea9f27f
 
 
532251c
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
532251c
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e9d98
 
ea9f27f
 
 
 
 
eecbd6f
f4ce0ac
ea9f27f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#origin

from seg import U2NETP
from GeoTr import GeoTr
from IllTr import IllTr
from inference_ill import rec_ill

import argparse
import functools
import glob
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.io as io
import numpy as np
import cv2
from PIL import Image
# Suppress every warning category process-wide (set before gradio is imported
# below so its import-time warnings are silenced too).
warnings.simplefilter('ignore')





import gradio as gr

# Image files offered as clickable examples in the Gradio UI
# (relative paths, resolved against the working directory at launch).
example_img_list = [
    '51_1 copy.png',
    '48_2 copy.png',
    '25.jpg',
]

class GeoTr_Seg(nn.Module):
    """Mask the input with a U2NETP segmentation, then run GeoTr on the result.

    Sub-modules (attribute names are relied on by the weight loaders):
        msk: ``U2NETP(3, 1)``; its first output is thresholded at 0.5.
        GeoTr: ``GeoTr(num_attn_layers=6)``; its output is rescaled in ``forward``.
    """

    # forward() maps GeoTr outputs in [0, _BM_SCALE] onto
    # [-_BM_SHRINK, _BM_SHRINK]; process_image later passes the result to
    # F.grid_sample, which samples with normalized coordinates in [-1, 1].
    _BM_SCALE = 286.8
    _BM_SHRINK = 0.99

    def __init__(self):
        super().__init__()
        self.msk = U2NETP(3, 1)
        self.GeoTr = GeoTr(num_attn_layers=6)

    def forward(self, x):
        """Return GeoTr's prediction for the masked input, rescaled as above.

        Args:
            x: input batch passed to both sub-modules (process_image feeds a
                (1, 3, 288, 288) float tensor).
        """
        seg_map, _, _, _, _, _, _ = self.msk(x)
        # Hard 0/1 mask: pixels where the segmentation output is <= 0.5 are zeroed.
        keep = (seg_map > 0.5).float()
        masked = keep * x

        raw = self.GeoTr(masked)
        normalized = raw / self._BM_SCALE
        return (2 * normalized - 1) * self._BM_SHRINK
        

def reload_model(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        #print(len(pretrained_dict.keys()))
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        #print(len(pretrained_dict.keys()))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        return model
        

def reload_segmodel(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        #print(len(pretrained_dict.keys()))
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
        #print(len(pretrained_dict.keys()))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        return model
        



@functools.lru_cache(maxsize=1)
def _load_geo_model():
    """Build GeoTr_Seg, load seg.pth / geotr.pth into it, and cache it for reuse."""
    model = GeoTr_Seg()  # .cuda()
    reload_segmodel(model.msk, './model_pretrained/seg.pth')
    reload_model(model.GeoTr, './model_pretrained/geotr.pth')
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_ill_model():
    """Build IllTr, load illtr.pth into it, and cache it (only built when needed)."""
    model = IllTr()  # .cuda()
    reload_model(model, './model_pretrained/illtr.pth')
    model.eval()
    return model


def _to_rgb_float(input_image):
    """Return *input_image* as an (H, W, 3) float array divided by 255."""
    arr = np.asarray(input_image)
    if arr.ndim == 2:
        arr = arr[:, :, None]
    if arr.shape[2] == 1:
        # Single-channel input: replicate it so the 3-channel model accepts it.
        arr = np.repeat(arr, 3, axis=2)
    # Extra channels (e.g. alpha) are dropped.
    return arr[:, :, :3] / 255.


def process_image(input_image, ill_rec=False):
    """Geometrically rectify a document photo with GeoTr_Seg.

    Models are loaded once per process and reused across calls instead of being
    rebuilt (and their weights re-read from disk) on every request.

    Args:
        input_image: array-like image of shape (H, W), (H, W, 1) or (H, W, C>=3)
            — presumably the numpy array delivered by the Gradio Image input;
            pixel values are assumed to be 8-bit (0-255).
        ill_rec: when True, additionally pass the rectified image through
            ``rec_ill`` with the IllTr model. Defaults to False, which matches
            the previous hard-coded behaviour.

    Returns:
        ``PIL.Image.Image`` of size (W, H) containing the rectified page.
    """
    im_ori = _to_rgb_float(input_image)
    h, w, _ = im_ori.shape

    # The model is run on a fixed 288x288 resize of the input.
    im = cv2.resize(im_ori, (288, 288)).transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = _load_geo_model()(im).cpu()
        # Upsample each of the two predicted channels to the original size and
        # lightly smooth them before using them as the sampling grid.
        bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))
        bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)

        src = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
        out = F.grid_sample(src, lbl, align_corners=True)
        # Clamp before the uint8 cast: out-of-range floats do not convert predictably.
        img_geo = (out[0] * 255).clamp(0, 255).permute(1, 2, 0).numpy().astype(np.uint8)

        if ill_rec:
            img_ill = rec_ill(_load_ill_model(), img_geo)
            return Image.fromarray(img_ill)
    return Image.fromarray(img_geo)



# Define Gradio interface.
# gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0;
# the top-level gr.Image component takes the same arguments (default type is
# numpy, so process_image still receives an array).
input_image = gr.Image()
output_image = gr.Image(type='pil')


iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr", examples=example_img_list)
iface.launch()