File size: 2,851 Bytes
240c20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7546a7
 
 
 
 
240c20c
b7546a7
 
240c20c
b7546a7
 
 
240c20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7546a7
240c20c
 
 
 
 
 
 
 
 
 
b79aac9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
import warnings
import gradio as gr

from model import DocGeoNet
from seg import U2NETP
import glob

# Silence every warning category process-wide, including any deprecation
# notices raised by the libraries imported above.
warnings.simplefilter('ignore')

class Net(nn.Module):
    """Two-stage rectification pipeline: foreground masking, then DocGeoNet.

    ``msk`` (a ``U2NETP(3, 1)``) predicts a document mask that is binarised
    at 0.5 and multiplied into the input; ``DocTr`` (``DocGeoNet``) then runs
    on the masked image. Its third output is rescaled by
    ``(2 * (v / 255) - 1) * 0.99`` so that 0 maps to -0.99 and 255 maps to
    0.99 -- the range ``rec`` feeds to ``F.grid_sample``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (msk, then DocTr) fixes the state_dict key order.
        self.msk = U2NETP(3, 1)
        self.DocTr = DocGeoNet()

    def forward(self, x):
        # U2NETP returns seven maps; only the first is used (presumably the
        # fused saliency map -- confirm against seg.U2NETP).
        seg_map, _s1, _s2, _s3, _s4, _s5, _s6 = self.msk(x)
        keep = (seg_map > 0.5).float()
        masked = keep * x

        # Only DocGeoNet's third output is consumed; the first two are ignored.
        _, _, backward_map = self.DocTr(masked)
        return (2 * (backward_map / 255.) - 1) * 0.99

def reload_seg_model(model, path=""):
    """Copy matching weights from a checkpoint file into ``model`` in place.

    Every checkpoint key has its first 6 characters removed (presumably a
    wrapper prefix such as ``"model."`` -- confirm against the checkpoint).
    Only keys that then name an entry of ``model.state_dict()`` are copied;
    everything else in the checkpoint is ignored, and model entries absent
    from the checkpoint keep their current values.

    Args:
        model: the ``nn.Module`` to populate.
        path: checkpoint path; a falsy value (e.g. ``""``) skips loading.

    Returns:
        The same ``model`` object.
    """
    if not path:
        return model

    current = model.state_dict()
    checkpoint = torch.load(path, map_location='cpu')
    for key, tensor in checkpoint.items():
        stripped = key[6:]
        if stripped in current:
            current[stripped] = tensor
    model.load_state_dict(current)
    return model

def reload_rec_model(model, path=""):
    """Copy matching weights from a checkpoint file into ``model`` in place.

    Identical to the segmentation loader except that the first 7 characters
    of each checkpoint key are dropped (presumably a ``"module."`` prefix
    left by a data-parallel wrapper -- confirm against the checkpoint).
    Non-matching checkpoint entries are ignored; model entries absent from
    the checkpoint keep their current values.

    Args:
        model: the ``nn.Module`` to populate.
        path: checkpoint path; a falsy value (e.g. ``""``) skips loading.

    Returns:
        The same ``model`` object.
    """
    if path:
        target = model.state_dict()
        source = torch.load(path, map_location='cpu')
        # Only existing keys are overwritten, so the key set of ``target``
        # never changes while the generator is consumed.
        target.update(
            (name[7:], weight)
            for name, weight in source.items()
            if name[7:] in target
        )
        model.load_state_dict(target)
    return model

# Build the full pipeline and load pretrained weights into each sub-network.
net = Net()
seg_model_path = './model_pretrained/preprocess.pth'
rec_model_path = './model_pretrained/DocGeoNet.pth'
reload_rec_model(net.DocTr, rec_model_path)
reload_seg_model(net.msk, seg_model_path)

# torch.compile only exists on PyTorch >= 2.0; on older versions keep the
# eager module instead of crashing at import time with AttributeError.
if hasattr(torch, 'compile'):
    net = torch.compile(net)

# Inference only: switch every submodule to evaluation mode.
net.eval()

def rec(input_image):
    """Geometrically rectify a document image with the module-level ``net``.

    The image is scaled to 0-1, resized to 256x256 and passed through
    ``net``. The two channels of the result (x and y flow) are resized back
    to the original resolution, smoothed with a 3x3 box blur, and used as
    the sampling grid for ``F.grid_sample`` on the full-resolution image.

    Args:
        input_image: an H x W x C or H x W array-like image with 0-255
            values (e.g. the numpy array Gradio supplies). Channels beyond
            the first three (such as alpha) are dropped; inputs with fewer
            than three channels are expanded by repeating the first channel.
            ``None`` (no image submitted) is passed straight through.

    Returns:
        The rectified image as a ``PIL.Image`` with the same channel order
        as the input, or ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    arr = np.asarray(input_image)
    if arr.ndim == 2:  # single-channel image: add a channel axis
        arr = arr[:, :, np.newaxis]
    if arr.shape[2] < 3:  # gray or gray+alpha: replicate the first channel
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    im_ori = arr[:, :, :3] / 255.  # 0-255 -> 0-1
    h, w, _ = im_ori.shape

    im = cv2.resize(im_ori, (256, 256))
    im = torch.from_numpy(im.transpose(2, 0, 1)).float().unsqueeze(0)

    with torch.no_grad():
        bm = net(im).cpu()

        # Upsample each flow component to the original size, then smooth it.
        bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))  # x flow
        bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))  # y flow
        grid = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2

        src = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
        out = F.grid_sample(src, grid, align_corners=True)

        # Channel order is preserved end to end, so no colour conversion is
        # needed before handing the array to PIL.
        img_rec = (out[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)
        return Image.fromarray(img_rec)


# Example inputs for the demo: every .jpg / .jpeg / .png file (any letter
# case) in ./distorted, sorted so the example order does not depend on the
# filesystem's directory listing order.
demo_img_files = sorted(
    path
    for pattern in (
        './distorted/*.[jJ][pP][gG]',
        './distorted/*.[jJ][pP][eE][gG]',
        './distorted/*.[pP][nN][gG]',
    )
    for path in glob.glob(pattern)
)

# Gradio Interface
# ``gr.Image`` replaces the ``gr.inputs`` / ``gr.outputs`` namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4. The input keeps the
# default numpy image type that ``rec`` consumes; the output accepts the
# PIL image that ``rec`` returns.
input_image = gr.Image()
output_image = gr.Image(type='pil')