import math
import os

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from model.nets import my_model

# Expose only the second GPU; once masked, it appears to PyTorch as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# First model: weights from ./mix.pth.
model1 = my_model(en_feature_num=48,
                  en_inter_num=32,
                  de_feature_num=64,
                  de_inter_num=32,
                  sam_number=1,
                  ).to(device)
load_path1 = "./mix.pth"
model_state_dict1 = torch.load(load_path1, map_location=device)
model1.load_state_dict(model_state_dict1)
model1.eval()

# Second model: weights from ./uhdm_checkpoint.pth (UHDM training set).
model2 = my_model(en_feature_num=48,
                  en_inter_num=32,
                  de_feature_num=64,
                  de_inter_num=32,
                  sam_number=1,
                  ).to(device)
load_path2 = "./uhdm_checkpoint.pth"
model_state_dict2 = torch.load(load_path2, map_location=device)
model2.load_state_dict(model_state_dict2)
model2.eval()


def img_pad(x, pad_l=0, pad_r=0, pad_t=0, pad_b=0):
    '''
    Pad each channel with the average r, g, b values computed across the
    FHDMi training set. For the evaluation on UHDM, padding with the mean
    values of the UHDM training set yields similar performance.
    '''
    x1 = F.pad(x[:, 0:1, ...], (pad_l, pad_r, pad_t, pad_b), value=0.3827)
    x2 = F.pad(x[:, 1:2, ...], (pad_l, pad_r, pad_t, pad_b), value=0.4141)
    x3 = F.pad(x[:, 2:3, ...], (pad_l, pad_r, pad_t, pad_b), value=0.3912)
    return torch.cat([x1, x2, x3], dim=1)


def demoire(model, img):
    '''Run one model on a PIL image and return the demoired PIL image.'''
    in_img = transforms.ToTensor()(img).unsqueeze(0).to(device)
    _, _, h, w = in_img.size()

    # Pad so that height and width are multiples of 32 (required by the
    # network's downsampling). The padding is split between the two sides;
    # when the required amount is odd, the extra pixel goes bottom/right.
    h_total = math.ceil(h / 32) * 32 - h
    w_total = math.ceil(w / 32) * 32 - w
    h_pad, w_pad = h_total // 2, w_total // 2
    in_img = img_pad(in_img,
                     pad_l=w_pad, pad_r=w_total - w_pad,
                     pad_t=h_pad, pad_b=h_total - h_pad)

    # The model returns multi-scale outputs; only the full-resolution
    # prediction is used here.
    with torch.no_grad():
        out_1, _, _ = model(in_img)

    # Crop the padding away and convert the result back to a PIL image.
    out_1 = out_1[:, :, h_pad:h_pad + h, w_pad:w_pad + w].squeeze(0)
    out_1 = torch.clamp(out_1 * 255, min=0, max=255).byte()
    return Image.fromarray(out_1.permute(1, 2, 0).cpu().numpy())


def predict1(img):
    return demoire(model1, img)


def predict2(img):
    return demoire(model2, img)


iface1 = gr.Interface(fn=predict1,
                      inputs=gr.inputs.Image(type="pil"),
                      outputs=gr.outputs.Image(type="pil"))
iface2 = gr.Interface(fn=predict2,
                      inputs=gr.inputs.Image(type="pil"),
                      outputs=gr.outputs.Image(type="pil"))

# Run both checkpoints side by side on the same input image.
iface_all = gr.mix.Parallel(
    iface1, iface2,
    examples=['001.jpg', '002.jpg', '003.jpg', '004.jpg', '005.jpg']
)

iface_all.launch()
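# Quick sanity check without the web UI (a sketch; it assumes one of the
# bundled example images, e.g. '001.jpg', sits next to this script, and the
# output filename is arbitrary):
#
#   restored = predict1(Image.open('001.jpg').convert('RGB'))
#   restored.save('001_demoired.png')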