# scene-sketch-seg / app_old.py
# Gradio demo for scene sketch semantic segmentation.
import gradio as gr
from PIL import Image
import torch
from utils import setup, get_similarity_map, display_segmented_sketch
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models
args = default_argument_parser().parse_args()
cfg = setup(args)
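
# Load the segmentation model (a CLIP ViT-B/16 variant) and its preprocessing
# transform. CPU is the default; swap in the commented CUDA check to use a GPU.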
device ="cpu"# "cuda" if torch.cuda.is_available() else "cpu"
Ours, preprocess = models.load("CS-ViT-B/16", device=device,cfg=cfg,train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# The checkpoint was trained with DataParallel on 2 GPUs, so strip the
# "module." prefix from each key before loading on a single device.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # drop the leading "module."
    new_state_dict[name] = v
Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")
def run(sketch, caption, threshold):
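    """Segment the objects named in `caption` in the sketch and return the overlay image."""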
    # Set the candidate classes here.
    classes = [caption]
    colors = plt.get_cmap("tab10").colors
    classes_colors = colors[2:len(classes) + 2]
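
    # gr.ImageEditor returns a dict of layers; "composite" is the flattened drawing.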
    sketch = sketch['composite']
    sketch = np.array(sketch)
    pil_img = Image.fromarray(sketch).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
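
    # Encode the caption with a prompt ensemble; features of the empty caption
    # serve as a redundant baseline that is subtracted from the class features.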
    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)
    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
                                            text_features=text_features - redundant_features,
                                            mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
        similarity = sketch_features @ (text_features - redundant_features).t()
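
    # Keep only the patch-token rows of the similarity map; the first
    # num_of_tokens + 1 rows appear to be global (class/prompt) tokens.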
    patches_similarity = similarity[0, num_of_tokens + 1:, :]
    pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
    # visualize_attention_maps_with_tokens(pixel_similarity, classes)

    # Zero out pixels that fall below the user-chosen confidence threshold.
    pixel_similarity[pixel_similarity < threshold] = 0
    pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)

    # display_segmented_sketch writes the colored overlay to output.png.
    display_segmented_sketch(pixel_similarity_array, sketch, classes, classes_colors, live=True)
    rgb_image = Image.open('output.png')
    return rgb_image
css=".gradio-container {background-color: black}"
demo = gr.Interface(
    fn=run,
    # js=js,
    css=css,
    theme="gstaff/sketch",  # xkcd
    description='Upload a sketch and find objects. '
                'Check the examples at the bottom of the page.',
    inputs=[
        gr.ImageEditor(label="Sketch", type="pil", sources="upload"),
        gr.Textbox(label="Caption", placeholder="Describe which objects to segment"),
        gr.Slider(label="Threshold", value=0.6, step=0.05, minimum=0, maximum=1),
    ],
    outputs=[gr.Image(type="pil", label="Segmented Sketch")],
    allow_flagging="never",
    examples=[
        ['demo/sketch_1.png', 'giraffe standing', 0.6],
        ['demo/sketch_2.png', 'tree', 0.6],
        ['demo/sketch_3.png', 'person', 0.6],
    ],
    title="Scene Sketch Semantic Segmentation",
)
if __name__ == "__main__":
demo.launch()