Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import torch | |
from torchvision.transforms import InterpolationMode | |
BICUBIC = InterpolationMode.BICUBIC | |
from utils import setup, get_similarity_map, display_segmented_sketch | |
from vpt.launch import default_argument_parser | |
from collections import OrderedDict | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import models | |
# Parse command-line arguments and build the project configuration from them.
args = default_argument_parser().parse_args()
cfg = setup(args)

# Inference is pinned to the CPU; the commented expression selects a GPU
# automatically when one is available.
device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"

Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# Trained on 2 GPUs, so the parameter names carry a "module." prefix that must
# be removed to load the weights on a single device. The prefix is stripped
# only when it is actually present, so a checkpoint saved without it keeps its
# key names intact instead of losing their first seven characters.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[name] = v

Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")
def run(sketch, caption, threshold):
    """Segment the parts of a sketch that match a text caption.

    Args:
        sketch: Value delivered by the ``gr.ImageEditor`` input; its
            ``'composite'`` entry is used as the sketch image.
        caption: Text describing the object(s) to segment. It is used as the
            single candidate class.
        threshold: Similarity cut-off. Entries of the similarity map that are
            below this value are set to 0 before rendering.

    Returns:
        An in-memory PIL image loaded from ``output.png`` after
        ``display_segmented_sketch`` has run.

    Raises:
        gr.Error: If no sketch was provided or the caption is empty.
    """
    # Fail with a readable UI message instead of an opaque traceback when the
    # user submits without uploading anything.
    if sketch is None or sketch['composite'] is None:
        raise gr.Error("Please upload a sketch first.")
    # An empty caption is identical to the "redundant" prompt below, so the
    # feature difference would be all zeros and the result meaningless.
    if not caption or not caption.strip():
        raise gr.Error("Please describe which objects to segment.")

    # Set the candidate classes here.
    classes = [caption]
    colors = plt.get_cmap("tab10").colors
    classes_colors = colors[2:len(classes) + 2]

    sketch = np.array(sketch['composite'])
    pil_img = Image.fromarray(sketch).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    # Number of leading tokens skipped (together with one more) when slicing
    # out the patch similarities below.
    num_of_tokens = 3

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(
            Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(
            Ours, [""], device, no_module=True)
        # Subtract the empty-prompt features once and reuse the result for
        # both the image encoder and the similarity product.
        class_features = text_features - redundant_features

        sketch_features = Ours.encode_image(
            sketch_tensor,
            layers=[12],
            text_features=class_features,
            mode="test",
        ).squeeze(0)
        # NOTE(review): normalizes along dim=1 -- confirm this is the intended
        # axis for the shape returned by encode_image.
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
        similarity = sketch_features @ class_features.t()

    patches_similarity = similarity[0, num_of_tokens + 1:, :]
    pixel_similarity = get_similarity_map(
        patches_similarity.unsqueeze(0), pil_img.size).cpu()

    # Suppress low-confidence responses in place.
    pixel_similarity[pixel_similarity < threshold] = 0
    pixel_similarity_array = pixel_similarity.numpy().transpose(2, 0, 1)

    display_segmented_sketch(
        pixel_similarity_array, sketch, classes, classes_colors, live=True)

    # NOTE(review): 'output.png' is presumably written by
    # display_segmented_sketch; it is a fixed path, so concurrent requests
    # could overwrite each other's result -- confirm against the helper.
    # Copy the pixels into memory so the file handle is closed before returning.
    with Image.open('output.png') as rendered:
        return rendered.copy()
# Page-level style override: black background for the Gradio container.
css = ".gradio-container {background-color: black}"

# Web UI: a sketch upload, a caption box and a threshold slider feed `run`,
# whose returned image is shown as the segmented sketch.
demo = gr.Interface(
    fn=run,
    css=css,
    theme="gstaff/sketch",  # xkcd
    description='Upload a sketch and find objects.'
                ' Check run examples down the page.',
    inputs=[
        # `sources` is passed as a list, the documented form for this argument.
        gr.ImageEditor(label="Sketch", type="pil", sources=["upload"]),
        gr.Textbox(label="Caption", placeholder="Describe which objects to segment"),
        gr.Slider(label="Threshold", value=0.6, step=0.05, minimum=0, maximum=1),
    ],
    outputs=[gr.Image(type="pil", label="Segmented Sketch")],
    # String form of "flagging disabled"; the boolean form is deprecated.
    allow_flagging="never",
    examples=[
        ['demo/sketch_1.png', 'giraffe standing', 0.6],
        ['demo/sketch_2.png', 'tree', 0.6],
        ['demo/sketch_3.png', 'person', 0.6],
    ],
    title="Scene Sketch Semantic Segmentation",
)

if __name__ == "__main__":
    demo.launch()