File size: 3,500 Bytes
ea9f27f
 
 
 
 
 
 
 
 
 
 
532251c
ea9f27f
 
532251c
ea9f27f
4cbf82a
ea9f27f
 
f4ce0ac
 
4cbf82a
 
f4ce0ac
980de13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cbf82a
ea9f27f
 
 
 
 
4cbf82a
ea9f27f
4cbf82a
ea9f27f
 
 
 
 
4cbf82a
ea9f27f
 
980de13
eb95227
 
4cbf82a
ea9f27f
eb95227
 
 
4cbf82a
ea9f27f
eb95227
 
4cbf82a
 
 
ea9f27f
 
 
4cbf82a
ea9f27f
 
 
4cbf82a
ea9f27f
 
 
 
532251c
ea9f27f
4cbf82a
 
ea9f27f
 
 
 
 
 
1b147a3
 
ea9f27f
 
 
 
 
 
 
1b147a3
ea9f27f
 
 
 
4cbf82a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from seg import U2NETP
from GeoTr import GeoTr
from IllTr import IllTr
from inference_ill import rec_ill

import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.io as io
import numpy as np
import cv2
import glob
import os
from PIL import Image
import argparse
import warnings

warnings.filterwarnings('ignore')

import gradio as gr

example_img_list = ['51_1 copy.png', '48_2 copy.png', '25.jpg']


def reload_model(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        # print(len(pretrained_dict.keys()))
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        # print(len(pretrained_dict.keys()))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        return model


def reload_segmodel(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        # print(len(pretrained_dict.keys()))
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
        # print(len(pretrained_dict.keys()))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        return model


class GeoTr_Seg(nn.Module):
    def __init__(self):
        super(GeoTr_Seg, self).__init__()
        self.msk = U2NETP(3, 1)
        self.GeoTr = GeoTr(num_attn_layers=6)

    def forward(self, x):
        msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
        msk = (msk > 0.5).float()
        x = msk * x

        bm = self.GeoTr(x)
        bm = (2 * (bm / 286.8) - 1) * 0.99

        return bm


# Initialize models
GeoTr_Seg_model = GeoTr_Seg()
# IllTr_model = IllTr()

# Load models only once
reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
# reload_model(IllTr_model, './model_pretrained/illtr.pth')

# Compile models (assuming PyTorch 2.0)
GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)


# IllTr_model = torch.compile(IllTr_model)

def process_image(input_image):
    GeoTr_Seg_model.eval()
    # IllTr_model.eval()

    im_ori = np.array(input_image)[:, :, :3] / 255.
    h, w, _ = im_ori.shape
    im = cv2.resize(im_ori, (288, 288))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = GeoTr_Seg_model(im)
        bm = bm.cpu()
        bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
        bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
        bm0 = cv2.blur(bm0, (3, 3))
        bm1 = cv2.blur(bm1, (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)

        out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
        img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)

        ill_rec = False

        if ill_rec:
            img_ill = rec_ill(IllTr_model, img_geo)
            return Image.fromarray(img_ill)
        else:
            return Image.fromarray(img_geo)


# Define Gradio interface
input_image = gr.inputs.Image()
output_image = gr.outputs.Image(type='pil')

iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr",
                     examples=example_img_list)
iface.launch()