import gradio as gr
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from utils import setup, get_similarity_map, display_segmented_sketch
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models

args = default_argument_parser().parse_args()
cfg = setup(args)

device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# The checkpoint was trained with DataParallel on 2 GPUs, so strip the "module." prefix to load it on a single device
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the `module.` prefix
    new_state_dict[name] = v
Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")
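# An equivalent one-liner for the prefix stripping on Python 3.9+ would be:
#   new_state_dict = OrderedDict((k.removeprefix("module."), v) for k, v in state_dict.items())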
    
def run(sketch, caption, threshold):

    # set the candidate classes here
    classes = [caption]

    # one display color per class, taken from matplotlib's tab10 palette starting at index 2
    colors = plt.get_cmap("tab10").colors
    classes_colors = colors[2:len(classes) + 2]

    # the ImageEditor value is a dict of layers; use the flattened "composite" image
    sketch = sketch['composite']
    sketch = np.array(sketch)
    
    pil_img = Image.fromarray(sketch).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        # encode the caption with a prompt ensemble; the empty-string embedding acts as a
        # redundant/background feature that is subtracted from the class embedding below
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)

    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
                                            text_features=text_features - redundant_features,
                                            mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
    similarity = sketch_features @ (text_features - redundant_features).t()
    # drop the class token and the prompt tokens, keeping only patch-level similarities
    patches_similarity = similarity[0, num_of_tokens + 1:, :]
    pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
    # visualize_attention_maps_with_tokens(pixel_similarity, classes)
    # zero out pixel similarities below the user-selected threshold
    pixel_similarity[pixel_similarity < threshold] = 0
    pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
    
    display_segmented_sketch(pixel_similarity_array, sketch, classes, classes_colors, live=True)

    # the segmented visualization is written to output.png by display_segmented_sketch; read it back for Gradio
    rgb_image = Image.open('output.png')

    return rgb_image
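
# A quick way to exercise run() outside the Gradio UI (hypothetical smoke test; mimics the
# dict shape the ImageEditor passes and assumes demo/sketch_1.png from the examples exists):
#
#   seg = run({'composite': Image.open('demo/sketch_1.png')}, 'giraffe standing', 0.6)
#   seg.save('segmented.png')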


css = ".gradio-container {background-color: black}"

demo = gr.Interface(
    fn=run,
    # js=js,
    css=css,
    theme="gstaff/sketch",  # alternative: "xkcd"
    description='Upload a sketch and find objects. '
                'Check the examples at the bottom of the page.',
    inputs=[
        gr.ImageEditor(
            label="Sketch", type="pil", sources="upload"),

        gr.Textbox(label="Caption", placeholder="Describe which objects to segment"),
        gr.Slider(label="Threshold", value=0.6, step=0.05, minimum=0, maximum=1),
    ],
    outputs=[gr.Image(type="pil", label="Segmented Sketch")],
    allow_flagging="never",
    examples=[
        ['demo/sketch_1.png', 'giraffe standing', 0.6],
        ['demo/sketch_2.png', 'tree', 0.6],
        ['demo/sketch_3.png', 'person', 0.6],
    ],
    title="Scene Sketch Semantic Segmentation")

if __name__ == "__main__":
    demo.launch()
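
# Launching with demo.launch(share=True) would additionally expose a temporary public Gradio
# link. Note that any command-line arguments expected by default_argument_parser()/setup()
# and the sketch_seg_best_miou.pth checkpoint must be available when the script is started.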