import os
import warnings
from glob import glob

import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr

from models.GCoNet import GCoNet


# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Number of image slots in the Gradio UI (one co-saliency group).
NUM_IMAGES = 5

# File suffixes treated as images when collecting example files.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')


class ImagePreprocessor():
    """Resize and normalize a PIL image into the tensor fed to GCoNet."""

    def __init__(self, size=(256, 256)) -> None:
        # ImageNet mean/std, so the incoming image must be in RGB channel order.
        self.transform_image = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image):
        """Return the resized, normalized tensor for the RGB PIL ``image``."""
        return self.transform_image(image)


model = GCoNet(bb_pretrained=False).to(device)
state_dict = './ultimate_duts_cocoseg (The best one).pth'
if os.path.exists(state_dict):
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
else:
    # Without this checkpoint the predictions are not meaningful; make that visible.
    warnings.warn(f'Checkpoint {state_dict!r} not found; GCoNet runs without its trained weights.')
# Switch to evaluation mode regardless of whether the checkpoint was found.
model.eval()


def _to_rgb(image):
    """Return ``image`` (a PIL image or an array-like) as an RGB PIL image."""
    if isinstance(image, Image.Image):
        return image.convert('RGB')
    return Image.fromarray(np.asarray(image)).convert('RGB')


def _predict(images):
    """Run GCoNet on one group of RGB PIL images.

    Args:
        images: non-empty list of RGB PIL images, processed together as one batch.

    Returns:
        List of RGB uint8 arrays, one per input, each ``[image | predicted map]``
        side by side; the map is resized to the image's own height and width.
    """
    preprocessor = ImagePreprocessor()
    batch = torch.stack([preprocessor.proc(image) for image in images]).to(device)
    outputs = []
    with torch.no_grad():
        # The model output is indexed with [-1]; its last element holds the per-image maps.
        scaled_preds = model(batch)[-1]
        for image, pred in zip(images, scaled_preds):
            rgb = np.asarray(image)
            height, width = rgb.shape[:2]
            pred = torch.nn.functional.interpolate(
                pred.unsqueeze(0), size=(height, width), mode='bilinear', align_corners=True)
            # Index instead of squeeze() so 1-pixel-high/wide images keep a 2-D map.
            # Assumes a single-channel map per image (as the original squeeze() did).
            pred = pred[0, 0].cpu().numpy()
            # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
            # NOTE(review): if GCoNet returns logits here, a sigmoid is needed first — confirm.
            mask = (np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)
            outputs.append(np.hstack([rgb, cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)]))
    return outputs


def pred_maps(dr):
    """Predict co-saliency maps for every readable image in directory ``dr``.

    Files are visited in sorted order; files OpenCV cannot decode are skipped.

    Returns:
        List of RGB uint8 arrays, each the image next to its predicted map.

    Raises:
        ValueError: if ``dr`` contains no image that OpenCV can read.
    """
    images = []
    for image_path in sorted(glob(os.path.join(dr, '*'))):
        bgr = cv2.imread(image_path)
        if bgr is None:  # unreadable or not an image
            continue
        # OpenCV loads BGR; the model's ImageNet normalization expects RGB.
        images.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
    if not images:
        raise ValueError(f'No readable images found in {dr!r}')
    return _predict(images)


def pred_maps_from_images(*images):
    """Gradio callback: one output per image slot, in the same order.

    Empty slots (``None``) are left out of the model batch and produce ``None``
    outputs, so the result always has exactly one entry per argument.
    """
    present = [index for index, image in enumerate(images) if image is not None]
    outputs = [None] * len(images)
    if present:
        preds = _predict([_to_rgb(images[index]) for index in present])
        for index, pred in zip(present, preds):
            outputs[index] = pred
    return tuple(outputs)


examples_dir = 'example_images/butterfly'
example_paths = sorted(
    path for path in glob(os.path.join(examples_dir, '*'))
    if path.lower().endswith(IMAGE_EXTENSIONS)
)[:NUM_IMAGES]
# Each example row needs one value per input component, so only offer a full group.
examples = [example_paths] if len(example_paths) == NUM_IMAGES else None

ipt = ['image' for _ in range(NUM_IMAGES)]
opt = ipt.copy()
demo = gr.Interface(
    fn=pred_maps_from_images,
    inputs=ipt,
    outputs=opt,
    examples=examples,
)

if __name__ == '__main__':
    demo.launch(debug=True)