sneha committed
Commit ffecbe9 · 1 Parent(s): cf87235
Files changed (2)
  1. app.py +6 -17
  2. attn_helper.py +0 -2
app.py CHANGED
@@ -8,7 +8,6 @@ import torch
 import matplotlib.pyplot as plt
 from attn_helper import VITAttentionGradRollout, overlay_attn
 import vc_models
-#import eaif_models
 import torchvision
 
 
@@ -64,7 +63,7 @@ def download_bin(model):
     os.rename(model_bin, bin_path)
 
 
-def run_attn(input_img, model="vc1-large",fusion="min",slider=0):
+def run_attn(input_img, model="vc1-large"):
     download_bin(model)
     model, embedding_dim, transform, metadata = get_model(model)
     if input_img.shape[0] != 3:
@@ -76,32 +75,22 @@ def run_attn(input_img, model="vc1-large",fusion="min",slider=0):
     input_img = resize_transform(input_img)
     x = transform(input_img)
 
-    attention_rollout = VITAttentionGradRollout(model,head_fusion=fusion,discard_ratio=slider)
+    attention_rollout = VITAttentionGradRollout(model,head_fusion="max",discard_ratio=0.89)
 
     y = model(x)
     mask = attention_rollout.get_attn_mask()
     attn_img = overlay_attn(input_img[0].permute(1,2,0),mask)
-
-    fig = plt.figure()
-    ax = fig.subplots()
-    print(y.shape)
-    im = ax.matshow(y.detach().numpy().reshape(16,-1))
-    plt.colorbar(im)
-
-    return attn_img, fig
+    return attn_img
 
 model_type = gr.Dropdown(
     ["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
 input_img = gr.Image(shape=(250,250))
-input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
 output_img = gr.Image(shape=(250,250))
-output_plot = gr.Plot()
-css = "#component-3, .input-image, .image-preview {height: 240px !important}"
-slider = gr.Slider(0, 1)
+css = "#component-2, .input-image, .image-preview {height: 240px !important}"
 markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
 The user can decide how the attention heads will be combined. \
 Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
 demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
-examples=[[os.path.join('./imgs',x),None,None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
-inputs=[input_img,model_type,input_button,slider],outputs=[output_img,output_plot],css=css)
+examples=[[os.path.join('./imgs',x),None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
+inputs=[input_img,model_type],outputs=output_img,css=css)
 demo.launch()
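Net effect on app.py: the commit drops the head-fusion radio, the discard-ratio slider, and the embedding matshow plot, hard-codes head_fusion="max" and discard_ratio=0.89, and returns only the attention overlay. (The description string still advertises user-selectable head fusion and the embedding grid, so it is now stale.) For reference, a minimal sketch of run_attn as it stands after this commit, reconstructed from the hunks above; the body of the channel check is not shown in the diff, and indentation is assumed:

def run_attn(input_img, model="vc1-large"):
    # Fetch the checkpoint and build the chosen model size.
    download_bin(model)
    model, embedding_dim, transform, metadata = get_model(model)
    if input_img.shape[0] != 3:
        ...  # channel handling, elided in this diff
    input_img = resize_transform(input_img)
    x = transform(input_img)

    # Head fusion and discard ratio are now fixed rather than user-controlled.
    attention_rollout = VITAttentionGradRollout(model, head_fusion="max", discard_ratio=0.89)

    y = model(x)  # the forward pass fills the attention hooks
    mask = attention_rollout.get_attn_mask()
    attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
    return attn_img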
attn_helper.py CHANGED
@@ -18,7 +18,6 @@ def overlay_attn(original_image,mask):
     # Apply colormap to mask
     cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
 
-    print(cmap.shape)
     # Blend mask and original image
     # grayscale_img = cv2.cvtColor(np.uint8(original_image), cv2.COLOR_RGB2GRAY)
     # grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)
@@ -45,7 +44,6 @@ class VITAttentionGradRollout:
         self.model = model
         self.head_fusion = head_fusion
         self.discard_ratio = discard_ratio
-        print(list(model.blocks.children()))
 
         self.attentions = {}
         for idx, module in enumerate(list(model.blocks.children())):
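For context on what these hooks feed: get_attn_mask presumably combines the per-layer attention maps collected in self.attentions via attention rollout. Below is a minimal sketch of that computation in the style of the standard attention-rollout recipe (Abnar & Zuidema, 2020), assuming maps shaped (batch, heads, tokens, tokens); the function name and shapes are illustrative assumptions, not this repo's API:

import torch

def attention_rollout(attentions, head_fusion="max", discard_ratio=0.89):
    # attentions: list of per-layer maps, each (batch, heads, tokens, tokens)
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attn in attentions:
            # Fuse heads the way head_fusion selects ("min" / "max" / "mean").
            if head_fusion == "mean":
                fused = attn.mean(dim=1)
            elif head_fusion == "max":
                fused = attn.max(dim=1)[0]
            else:
                fused = attn.min(dim=1)[0]

            # Zero the weakest attention entries, never the class token itself.
            flat = fused.view(fused.size(0), -1)
            _, idx = flat.topk(int(flat.size(-1) * discard_ratio), dim=-1, largest=False)
            idx = idx[idx != 0]
            flat[0, idx] = 0

            # Account for residual connections, renormalize rows, accumulate.
            eye = torch.eye(fused.size(-1))
            layer = (fused + eye) / 2
            layer = layer / layer.sum(dim=-1, keepdim=True)
            result = torch.matmul(layer, result)

    # Attention flow from the class token to each patch, as a square grid.
    mask = result[0, 0, 1:]
    width = int(mask.numel() ** 0.5)
    return (mask / mask.max()).reshape(width, width)

With discard_ratio=0.89 as now hard-coded in app.py, all but roughly the top 11% of attention entries per layer are zeroed before the rollout product, which sharpens the resulting heatmap.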