File size: 2,507 Bytes
ea9f27f
 
 
 
 
 
 
 
 
 
 
532251c
ea9f27f
 
532251c
ea9f27f
 
 
f4ce0ac
 
eecbd6f
f4ce0ac
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb95227
 
 
ea9f27f
eb95227
 
 
 
ea9f27f
eb95227
 
 
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
532251c
ea9f27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eecbd6f
f4ce0ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from seg import U2NETP
from GeoTr import GeoTr
from IllTr import IllTr
from inference_ill import rec_ill

import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.io as io
import numpy as np
import cv2
import glob
import os
from PIL import Image
import argparse
import warnings
warnings.filterwarnings('ignore')

import gradio as gr

# Sample document images offered as one-click examples in the Gradio UI.
# These are relative paths, so they resolve against the process's current
# working directory.
example_img_list = [
    '51_1 copy.png',
    '48_2 copy.png',
    '25.jpg',
]

class GeoTr_Seg(nn.Module):
    """Mask the document with a segmentation net, then predict its backward map.

    Sub-modules (attribute names are referenced elsewhere when loading
    weights, so they must stay ``msk`` and ``GeoTr``):
      * ``msk``   -- ``U2NETP(3, 1)``; only the first of its seven outputs is
        used, thresholded into a binary foreground mask.
      * ``GeoTr`` -- ``GeoTr(num_attn_layers=6)``; predicts the backward map
        from the masked image.
    """

    def __init__(self):
        super().__init__()
        self.msk = U2NETP(3, 1)
        self.GeoTr = GeoTr(num_attn_layers=6)

    def forward(self, x):
        """Return the rescaled backward map for the image batch ``x``.

        Assumes ``x`` is an (N, 3, H, W) float tensor (the caller in this
        script passes a 1x3x288x288 tensor scaled to [0, 1]) -- TODO confirm
        for other callers.
        """
        # The segmentation network yields seven tensors; keep only the first.
        seg_out, _, _, _, _, _, _ = self.msk(x)

        # Binarize at 0.5 and zero every pixel outside the mask.
        keep = (seg_out > 0.5).float()
        masked = keep * x

        raw_bm = self.GeoTr(masked)

        # Rescale the predicted coordinates towards the [-1, 1] range that
        # F.grid_sample consumes, shrunk by 0.99. The 286.8 divisor presumably
        # relates to the 288-px network input -- NOTE(review): confirm.
        scaled = raw_bm / 286.8
        scaled = 2 * scaled - 1
        return scaled * 0.99

# Initialize models
# Both networks are constructed once at import time and shared by every
# call to `process_image` below.
GeoTr_Seg_model = GeoTr_Seg()
IllTr_model = IllTr()

# Load models only once
# NOTE(review): `reload_segmodel` and `reload_model` are neither defined nor
# imported anywhere in this file, so these lines raise NameError as written.
# TODO: import them from the module that defines them (presumably the original
# DocTr inference script -- confirm) or define them here.
# The segmentation sub-module (`.msk`) goes through a separate loader from the
# geometric (`.GeoTr`) and illumination models; why the two loaders differ is
# not visible from this file.
reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
reload_model(IllTr_model, './model_pretrained/illtr.pth')

# Compile models (assuming PyTorch 2.0)
# Keep this AFTER weight loading: torch.compile returns a wrapper module and the
# loaders above reach into the original sub-modules (`.msk`, `.GeoTr`) directly;
# presumably the wrapper's state_dict keys would also no longer match the
# checkpoints -- NOTE(review): confirm before reordering.
GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
IllTr_model = torch.compile(IllTr_model)

def process_image(input_image, *, ill_rec=False):
    """Geometrically rectify a document photo, optionally also fixing illumination.

    Args:
        input_image: anything ``np.array`` turns into an image array -- the
            Gradio input in this script delivers an H x W x C uint8 array.
            Channels beyond the first three (e.g. alpha) are dropped; 2-D or
            single-channel images are replicated to three channels.
        ill_rec: keyword-only. When True, the unwarped image is additionally
            passed through the illumination model via ``rec_ill``. Defaults to
            False, which is what the Gradio UI uses since it supplies only
            the image argument.

    Returns:
        A ``PIL.Image`` of the rectified document at the input's resolution,
        or ``None`` when no image was supplied.
    """
    # Gradio passes None when the user submits without uploading an image.
    if input_image is None:
        return None

    GeoTr_Seg_model.eval()
    IllTr_model.eval()

    arr = np.array(input_image)
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]
    if arr.ndim == 2:
        # Grayscale: replicate into three identical channels so the
        # [:, :, :3] slice and the HWC -> CHW transpose below are valid.
        arr = np.stack([arr] * 3, axis=-1)

    # Keep at most three channels and scale to [0, 1].
    im_ori = arr[:, :, :3] / 255.
    h, w, _ = im_ori.shape

    # The networks run at a fixed 288x288 resolution, in 1 x C x H x W layout.
    im = cv2.resize(im_ori, (288, 288))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = GeoTr_Seg_model(im).cpu()

        # Upsample each backward-map channel to the original resolution and
        # smooth it with a 3x3 box filter before using it as a sampling grid.
        bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))
        bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)

        # Warp the full-resolution original through the predicted grid.
        src = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
        out = F.grid_sample(src, lbl, align_corners=True)
        img_geo = (out[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)

        if ill_rec:
            img_ill = rec_ill(IllTr_model, img_geo)
            return Image.fromarray(img_ill)
        return Image.fromarray(img_geo)

# Define Gradio interface
# `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level `gr.Image` component works in both. As an input its
# default type is "numpy" (RGB), matching what `process_image` expects; as the
# output, type='pil' matches the PIL image `process_image` returns.
input_image = gr.Image()
output_image = gr.Image(type='pil')

iface = gr.Interface(
    fn=process_image,
    inputs=input_image,
    outputs=output_image,
    title="DocTr",
    examples=example_img_list,
)
iface.launch()