File size: 2,693 Bytes
7febe9c
 
 
 
 
 
 
d10f17e
a7ce0cb
7febe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2676e90
1ff61c1
 
7febe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7de572
a6ceb33
7febe9c
 
726f866
a6ceb33
614027e
726f866
70ae798
 
7febe9c
 
726f866
 
614027e
87d115a
398e876
a6ceb33
7febe9c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
from glob import glob
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr
import spaces

from models.GCoNet import GCoNet


# Inference device: index 0 selects 'cpu'; change the index to 1 for 'cuda'.
device = ('cpu', 'cuda')[0]


class ImagePreprocessor:
    """Resize, tensorize and normalize a single image for the network."""

    def __init__(self) -> None:
        # Fixed 256x256 input resolution.
        resize = transforms.Resize((256, 256))
        to_tensor = transforms.ToTensor()
        # Per-channel mean / std; these match the widely used ImageNet
        # channel statistics.
        normalize = transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        )
        self.transform_image = transforms.Compose([resize, to_tensor, normalize])

    def proc(self, image):
        """Return `image` after running it through the transform pipeline."""
        return self.transform_image(image)


# Build the network on the selected device.
# NOTE(review): bb_pretrained=False presumably skips loading pretrained
# backbone weights, leaving the checkpoint below as the only source of
# trained weights -- confirm in models/GCoNet.
model = GCoNet(bb_pretrained=False).to(device)
# Despite its name, `state_dict` holds the checkpoint *path*; the loaded
# weights are stored in `gconet_dict`.
state_dict = './ultimate_duts_cocoseg (The best one).pth'
if os.path.exists(state_dict):
    # torch.load unpickles the file, so only load checkpoints that come from
    # a trusted source.
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
else:
    # Keep the demo running (best effort), but do not hide the fact that no
    # checkpoint weights were loaded.
    print(
        f"[warning] Checkpoint not found at '{state_dict}'; "
        "continuing WITHOUT loading trained weights.",
        flush=True,
    )
# Inference only: switch layers such as dropout / batch norm to eval mode.
model.eval()


@spaces.GPU
def pred_maps(image_1, image_2, image_3, image_4):
    """Predict co-saliency maps for a group of up to four images.

    Each argument is either an image array indexed ``[row, column, ...]``
    (anything ``PIL.Image.fromarray`` accepts) or ``None`` for a slot that
    was left empty. All provided images are sent through the model together
    as one batch.

    Returns:
        A list with one entry per argument, in order. For a provided image
        the entry is a uint8 RGB array showing the input on the left and the
        predicted map, resized to the input's height and width, on the right.
        For an empty slot the entry is ``None``.
    """
    slots = [image_1, image_2, image_3, image_4]
    outputs = [None] * len(slots)

    # Empty slots arrive as None; skip them instead of failing on `.shape`.
    filled_slots = [idx for idx, image in enumerate(slots) if image is not None]
    if not filled_slots:
        return outputs

    # Remember each original (height, width) so the map can be resized back.
    image_shapes = [slots[idx].shape[:2] for idx in filled_slots]
    # Convert to RGB up front so the 3-channel normalization in the
    # preprocessor always receives three channels.
    images = [Image.fromarray(slots[idx]).convert('RGB') for idx in filled_slots]

    image_preprocessor = ImagePreprocessor()
    batch = torch.stack([image_preprocessor.proc(image) for image in images])

    with torch.no_grad():
        # The model returns a sequence; its last element is used as the
        # per-image predictions.
        scaled_preds_tensor = model(batch.to(device))[-1]

    for slot, image, image_shape, pred_tensor in zip(
            filled_slots, images, image_shapes, scaled_preds_tensor):
        # `.cpu()` is a no-op for tensors that already live on the CPU.
        pred = torch.nn.functional.interpolate(
            pred_tensor.cpu().unsqueeze(0),
            size=image_shape,
            mode='bilinear',
            align_corners=True,
        ).squeeze().numpy()
        # NOTE(review): assumes `pred` lies in [0, 1]; values outside that
        # range would wrap around in the uint8 cast -- confirm against the
        # model's output.
        pred_rgb = cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        # Input on the left, predicted map on the right.
        outputs[slot] = np.hstack([np.array(image), pred_rgb])
    return outputs


# Number of image slots in the UI; pred_maps takes exactly this many inputs.
N = 4
# examples = [[_] for _ in glob('example_images/butterfly/*')][:N]


def _image_slots(count):
    """Create `count` identically sized Gradio image components."""
    return [gr.Image(width=600, height=300) for _ in range(count)]


ipt = _image_slots(N)
opt = _image_slots(N)
demo = gr.Interface(
    fn=pred_maps,
    inputs=ipt,
    outputs=opt,
    # examples=examples,
    # interpretation='default',
    title='Online demo for `GCoNet+: A Stronger Group Collaborative Co-Salient Object Detector (T-PAMI 2023)`',
    description='Upload pictures, most of which contain salient objects of the same class. Our demo will give you the binary maps of these co-salient objects :)\n**********Example images need to be dropped into each block, instead of click.**********',
)
# Start the web server; debug=True is passed through to Gradio unchanged.
demo.launch(debug=True)