ahmedbrs committed on
Commit
20d921e
·
1 Parent(s): 254fdf2
Files changed (1) hide show
  1. app_old.py +0 -92
app_old.py DELETED
@@ -1,92 +0,0 @@
1
-
2
- import gradio as gr
3
- from PIL import Image
4
- import torch
5
- from torchvision.transforms import InterpolationMode
6
- BICUBIC = InterpolationMode.BICUBIC
7
- from utils import setup, get_similarity_map, display_segmented_sketch
8
- from vpt.launch import default_argument_parser
9
- from collections import OrderedDict
10
- import numpy as np
11
- import matplotlib.pyplot as plt
12
- import models
13
-
14
# Parse the command-line options and build the configuration used to load the model.
args = default_argument_parser().parse_args()
cfg = setup(args)

# Inference is pinned to the CPU. The CUDA auto-detection is kept for reference:
# "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# Trained on 2 gpus so we need to remove the prefix "module." to test it on a single GPU.
# The prefix is stripped only from keys that actually carry it, so a checkpoint
# saved without the wrapper prefix is loaded unchanged instead of having the
# first seven characters of every key chopped off.
module_prefix = "module."
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len(module_prefix):] if k.startswith(module_prefix) else k
    new_state_dict[name] = v
Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")
29
-
30
def run(sketch, caption, threshold):
    """Segment the regions of a sketch that match a text caption.

    Args:
        sketch: Value delivered by the sketch input component; its
            ``'composite'`` entry is read and converted to a NumPy array.
        caption: Text used as the single candidate class.
        threshold: Similarity values strictly below this are set to zero
            before the segmentation is rendered.

    Returns:
        A PIL image loaded from ``output.png``.
        NOTE(review): ``output.png`` is presumably written by
        ``display_segmented_sketch(..., live=True)`` -- confirm in ``utils``.
    """
    # set the candidate classes here
    classes = [caption]

    # One colour per class, starting at the third entry of the "tab10" palette.
    colors = plt.get_cmap("tab10").colors
    classes_colors = colors[2:len(classes) + 2]

    sketch = np.array(sketch['composite'])

    pil_img = Image.fromarray(sketch).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    # The first `num_of_tokens + 1` tokens are skipped when slicing out the
    # patch similarities (presumably the class token plus prompt tokens --
    # TODO confirm against the model definition).
    num_of_tokens = 3

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(
            Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(
            Ours, [""], device, no_module=True)
        # The caption embedding minus the empty-prompt embedding is used both
        # as conditioning for the image encoder and as the similarity target.
        text_delta = text_features - redundant_features

        sketch_features = Ours.encode_image(
            sketch_tensor, layers=[12], text_features=text_delta, mode="test"
        ).squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
        similarity = sketch_features @ text_delta.t()
        patches_similarity = similarity[0, num_of_tokens + 1:, :]
        pixel_similarity = get_similarity_map(
            patches_similarity.unsqueeze(0), pil_img.size).cpu()

    pixel_similarity[pixel_similarity < threshold] = 0
    pixel_similarity_array = pixel_similarity.numpy().transpose(2, 0, 1)

    display_segmented_sketch(pixel_similarity_array, sketch, classes, classes_colors, live=True)

    # Image.open is lazy: load a detached copy and close the file immediately,
    # so no handle is leaked and the returned image does not depend on
    # 'output.png' still holding this request's result.
    with Image.open('output.png') as rendered:
        rgb_image = rendered.copy()

    return rgb_image
64
-
65
-
66
# Force a black page background behind the sketch-themed UI.
css = ".gradio-container {background-color: black}"

# Gradio UI: a sketch upload, a caption and a threshold go in; the segmented
# sketch returned by `run` comes out.
demo = gr.Interface(
    fn=run,
    css=css,
    theme="gstaff/sketch",  # xkcd
    description='Upload a sketch and find objects.'
                ' Check run examples down the page.',
    inputs=[
        gr.ImageEditor(
            label="Sketch", type="pil", sources="upload"),
        gr.Textbox(label="Caption", placeholder="Describe which objects to segment"),
        gr.Slider(label="Threshold", value=0.6, step=0.05, minimum=0, maximum=1),
    ],
    outputs=[gr.Image(type="pil", label="Segmented Sketch")],
    # The string form replaces the deprecated boolean `False`.
    allow_flagging="never",
    examples=[
        ['demo/sketch_1.png', 'giraffe standing', 0.6],
        ['demo/sketch_2.png', 'tree', 0.6],
        ['demo/sketch_3.png', 'person', 0.6],
    ],
    title="Scene Sketch Semantic Segmentation",
)

if __name__ == "__main__":
    demo.launch()