# Source: ZhengPeng7, commit 398e876 (2.66 kB):
# "Polish up the interface. Correct the specified examples."
import os
import warnings
from glob import glob

import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr

from models.GCoNet import GCoNet
device = 'cpu'  # inference device; set to 'cuda' to run the model on a GPU
class ImagePreprocessor:
    """Turn a PIL image into the normalized tensor the GCoNet model consumes."""

    def __init__(self, size=(256, 256)) -> None:
        """Build the resize -> to-tensor -> normalize pipeline.

        Args:
            size: ``(height, width)`` passed to ``transforms.Resize``; defaults
                to the 256x256 resolution this demo has always used.
        """
        self.transform_image = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std, in RGB channel order.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image):
        """Return ``image`` after applying the preprocessing pipeline."""
        return self.transform_image(image)
# Build the network (backbone not pretrained) and load the trained checkpoint.
model = GCoNet(bb_pretrained=False).to(device)
state_dict = './ultimate_duts_cocoseg (The best one).pth'
if os.path.exists(state_dict):
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
else:
    # Keep the demo starting, but do not fail silently: without this file the
    # predictions come from weights that were never loaded from the checkpoint.
    warnings.warn(
        f'Checkpoint {state_dict!r} not found; GCoNet is running without its trained weights.'
    )
# Inference-only demo: always switch to evaluation mode.
model.eval()
def _read_rgb_images_from_dir(dr):
    """Load every decodable image directly inside ``dr`` as an RGB PIL image.

    Paths are visited in sorted order so results are reproducible; entries that
    ``cv2.imread`` cannot decode (it returns ``None`` for them, e.g. for
    sub-directories or non-image files) are skipped instead of crashing.
    """
    images = []
    for image_path in sorted(glob(os.path.join(dr, '*'))):
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            continue
        # OpenCV decodes to BGR, but the ImageNet normalization applied by
        # ImagePreprocessor expects RGB channel order.
        images.append(Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)))
    return images


def _predict_composites(images):
    """Run the model on one group of RGB PIL images and build result panels.

    All images are fed to the network as a single batch. Each returned panel is
    a uint8 RGB array: the input image on the left and, on the right, its
    predicted map resized back to the input's height/width and shown in gray.
    Returns an empty list for an empty group.
    """
    if not images:
        # torch.stack would raise on an empty list.
        return []
    image_preprocessor = ImagePreprocessor()
    batch = torch.stack([image_preprocessor.proc(image) for image in images])
    with torch.no_grad():
        scaled_preds_tensor = model(batch.to(device))[-1]
    composites = []
    for image, pred_tensor in zip(images, scaled_preds_tensor):
        pred = torch.nn.functional.interpolate(
            pred_tensor.cpu().unsqueeze(0),
            size=(image.height, image.width),
            mode='bilinear',
            align_corners=True,
        ).squeeze().numpy()
        # NOTE(review): assumes the map is meant to lie in [0, 1]; clipping keeps
        # any out-of-range value from wrapping around in the uint8 cast.
        pred_gray = (np.clip(pred, 0, 1) * 255).astype(np.uint8)
        composites.append(
            np.hstack([np.asarray(image), cv2.cvtColor(pred_gray, cv2.COLOR_GRAY2RGB)])
        )
    return composites


def pred_maps(dr, *more_images):
    """Predict co-salient object maps for one group of images.

    Two calling conventions are supported:

    * ``pred_maps(dr)`` with a directory path (``str`` or ``os.PathLike``):
      every decodable image file in ``dr`` is processed, and a list with one
      result panel per image is returned, in sorted filename order.
    * ``pred_maps(img_1, ..., img_k)``, as the Gradio interface calls it with
      one argument per image input component. Each argument is an image array
      (presumably RGB uint8, Gradio's default numpy image type) or ``None`` for
      an empty slot. The result has exactly one entry per argument: the panel,
      or ``None`` where the slot was empty, because Gradio expects one value
      per output component.

    Each result panel is a uint8 RGB array showing the input image next to its
    predicted map.
    """
    if isinstance(dr, (str, os.PathLike)):
        return _predict_composites(_read_rgb_images_from_dir(dr))

    slots = (dr, *more_images)
    filled = [idx for idx, image in enumerate(slots) if image is not None]
    # .convert('RGB') also normalizes grayscale or RGBA uploads to 3 channels.
    composites = _predict_composites(
        [Image.fromarray(np.asarray(slots[idx])).convert('RGB') for idx in filled]
    )
    outputs = [None] * len(slots)
    for idx, composite in zip(filled, composites):
        outputs[idx] = composite
    return outputs
# Gradio UI: five image input slots and five image output slots.
ipt = ['image'] * 5
opt = ipt.copy()
example_paths = sorted(glob('example_images/butterfly/*'))
# With several input components, Gradio requires each example to be a list
# holding one value per component (a flat list of paths is rejected), so use
# the first five files as a single example row when enough are available.
examples = [example_paths[:len(ipt)]] if len(example_paths) >= len(ipt) else None
demo = gr.Interface(
    fn=pred_maps,
    inputs=ipt,
    outputs=opt,
    examples=examples,
    interpretation='default',
    title='Online demo for `GCoNet+: A Stronger Group Collaborative Co-Salient Object Detector (T-PAMI 2023)`',
    description='Upload pictures, most of which contain salient objects of the same class.\nOur demo will give you the binary maps of these co-salient objects :)'
)
demo.launch(debug=True)