import warnings

warnings.filterwarnings('ignore')

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from seg import U2NETP
from GeoTr import GeoTr
from IllTr import IllTr
from inference_ill import rec_ill
# Example inputs shipped with the Space.
example_img_list = ['51_1 copy.png', '48_2 copy.png', '25.jpg']
def reload_model(model, path=""):
    """Load a checkpoint into `model`, stripping a 7-character key prefix
    (e.g. the 'module.' prefix left behind by DataParallel training)."""
    if not bool(path):
        return model
    model_dict = model.state_dict()
    pretrained_dict = torch.load(path, map_location='cpu')
    # Keep only the checkpoint entries that match a model key once the
    # prefix is removed.
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
def reload_segmodel(model, path=""):
    """Same as reload_model, but strips a 6-character key prefix, matching
    the key naming used in the segmentation checkpoint."""
    if not bool(path):
        return model
    model_dict = model.state_dict()
    pretrained_dict = torch.load(path, map_location='cpu')
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
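

# The two loaders above differ only in how many leading characters they strip
# from each checkpoint key (7 for a 'module.' DataParallel prefix, 6 for the
# segmentation checkpoint's prefix). A merged sketch of the same idea; this
# helper is hypothetical and unused below:
def reload_with_prefix(model, path, prefix_len):
    state = torch.load(path, map_location='cpu')
    model_dict = model.state_dict()
    # Strip the fixed-length prefix and keep only keys the model knows.
    model_dict.update({k[prefix_len:]: v for k, v in state.items()
                       if k[prefix_len:] in model_dict})
    model.load_state_dict(model_dict)
    return model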
class GeoTr_Seg(nn.Module):
    """Document segmentation (U2NETP) followed by geometric unwarping (GeoTr)."""

    def __init__(self):
        super(GeoTr_Seg, self).__init__()
        self.msk = U2NETP(3, 1)
        self.GeoTr = GeoTr(num_attn_layers=6)

    def forward(self, x):
        # Binarize the predicted foreground mask and use it to blank out
        # the background before unwarping.
        msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
        msk = (msk > 0.5).float()
        x = msk * x
        # GeoTr predicts a backward map in pixel coordinates; rescale it to
        # the [-1, 1] range expected by F.grid_sample (the 0.99 factor keeps
        # samples strictly inside the image).
        bm = self.GeoTr(x)
        bm = (2 * (bm / 286.8) - 1) * 0.99
        return bm
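

# A minimal shape sketch of the model contract (illustrative only, so it is
# left commented out to avoid running a forward pass at import time; the
# output shape is inferred from the 288x288 input used in process_image below):
#
#   x = torch.rand(1, 3, 288, 288)   # N, C, H, W batch in [0, 1]
#   bm = GeoTr_Seg()(x)              # -> backward map, values in [-1, 1]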
# Initialize the models and load the pretrained weights once at startup.
GeoTr_Seg_model = GeoTr_Seg()
# IllTr_model = IllTr()
reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
# reload_model(IllTr_model, './model_pretrained/illtr.pth')

# Compile the model where torch.compile is available (PyTorch >= 2.0).
if hasattr(torch, 'compile'):
    GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
    # IllTr_model = torch.compile(IllTr_model)
def process_image(input_image):
    GeoTr_Seg_model.eval()
    # IllTr_model.eval()

    # Normalize to [0, 1] RGB and remember the original resolution.
    im_ori = np.array(input_image)[:, :, :3] / 255.
    h, w, _ = im_ori.shape
    # The network operates at a fixed 288x288 resolution.
    im = cv2.resize(im_ori, (288, 288))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = GeoTr_Seg_model(im)
        bm = bm.cpu()
        # Upsample the two backward-map channels to the original resolution
        # and smooth them slightly.
        bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
        bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
        bm0 = cv2.blur(bm0, (3, 3))
        bm1 = cv2.blur(bm1, (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
        # Resample the original image along the backward map to rectify it.
        out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
        img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)

    # Illumination rectification is disabled; enabling it would require
    # IllTr_model to be loaded above.
    ill_rec = False
    if ill_rec:
        img_ill = rec_ill(IllTr_model, img_geo)
        return Image.fromarray(img_ill)
    return Image.fromarray(img_geo)
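

# A small convenience wrapper for local testing (a sketch, not part of the
# original app; the file paths are hypothetical examples):
def process_file(in_path, out_path):
    rectified = process_image(Image.open(in_path).convert('RGB'))
    rectified.save(out_path)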
# Define the Gradio interface. Recent Gradio versions removed the
# gr.inputs/gr.outputs namespaces, so the components are created directly.
iface = gr.Interface(fn=process_image,
                     inputs=gr.Image(type='pil'),
                     outputs=gr.Image(type='pil'),
                     title="DocTr",
                     examples=example_img_list)
iface.launch()