sneha committed
Commit cf37148 · 1 Parent(s): 9ee651f

first commit
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
-title: Vc1 Large
-emoji: 👁
-colorFrom: purple
-colorTo: gray
+title: Visual Cortex Demo
+emoji: 🏢
+colorFrom: red
+colorTo: purple
 sdk: gradio
-sdk_version: 3.24.1
+sdk_version: 3.23.0
 app_file: app.py
 pinned: false
-license: cc-by-4.0
+license: cc-by-nc-4.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,105 @@
+import numpy as np
+import gradio as gr
+from huggingface_hub import hf_hub_download
+import omegaconf
+from hydra import utils
+import os
+import torch
+import matplotlib.pyplot as plt
+from attn_helper import VITAttentionGradRollout, overlay_attn
+import vc_models
+# import eaif_models
+import torchvision
+
+
+HF_TOKEN = os.environ['HF_ACC_TOKEN']
+
+# Checkpoints are cached next to the installed vc_models package.
+eai_filepath = vc_models.__file__.split('src')[0]
+MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
+if not os.path.isdir(MODEL_DIR):
+    os.mkdir(MODEL_DIR)
+
+FILENAME = "config.yaml"
+BASE_MODEL_TUPLE = None
+LARGE_MODEL_TUPLE = None
+
+
+def get_model(model_name):
+    """Load (and cache) the requested VC-1 tuple: (model, embedding_dim, transform, metadata)."""
+    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
+    download_bin(model_name)
+    model = None
+    if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
+        repo_name = "facebook/" + model_name
+        model_cfg = omegaconf.OmegaConf.load(
+            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
+        )
+        BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
+        BASE_MODEL_TUPLE[0].eval()
+        model = BASE_MODEL_TUPLE
+    elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
+        repo_name = "facebook/" + model_name
+        model_cfg = omegaconf.OmegaConf.load(
+            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
+        )
+        LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
+        LARGE_MODEL_TUPLE[0].eval()
+        model = LARGE_MODEL_TUPLE
+    elif model_name == 'vc1-base':
+        model = BASE_MODEL_TUPLE
+    elif model_name == 'vc1-large':
+        model = LARGE_MODEL_TUPLE
+
+    return model
+
+
+def download_bin(model):
+    """Download the model checkpoint into MODEL_DIR if it is not already there."""
+    if model == "vc1-large":
+        bin_file = 'vc1_vitl.pth'
+    elif model == "vc1-base":
+        bin_file = 'vc1_vitb.pth'
+    else:
+        raise NameError("model not found: " + model)
+
+    bin_path = os.path.join(MODEL_DIR, bin_file)
+    if not os.path.isfile(bin_path):
+        # Pull the checkpoint from the model's Hugging Face repo.
+        model_bin = hf_hub_download(repo_id="facebook/" + model, filename='pytorch_model.bin',
+                                    local_dir=MODEL_DIR, local_dir_use_symlinks=True, token=HF_TOKEN)
+        os.rename(model_bin, bin_path)
+
+
+def run_attn(input_img, model="vc1-large", fusion="min"):
+    """Compute the attention overlay and the embedding plot for one input image."""
+    download_bin(model)
+    model, embedding_dim, transform, metadata = get_model(model)
+
+    # Gradio passes an HWC uint8 array; convert to a CHW float tensor with a batch dimension.
+    if input_img.shape[0] != 3:
+        input_img = input_img.transpose(2, 0, 1)
+    if len(input_img.shape) == 3:
+        input_img = torch.tensor(input_img).unsqueeze(0)
+    input_img = input_img.float()
+    resize_transform = torchvision.transforms.Resize((250, 250))
+    input_img = resize_transform(input_img)
+    x = transform(input_img)
+
+    attention_rollout = VITAttentionGradRollout(model, head_fusion=fusion)
+
+    y = model(x)
+    mask = attention_rollout.get_attn_mask()
+    attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
+
+    # Show the embedding values as a 16-row grid (16x48 for vc1-base, 16x64 for vc1-large).
+    fig = plt.figure()
+    ax = fig.subplots()
+    im = ax.matshow(y.detach().numpy().reshape(16, -1))
+    plt.colorbar(im)
+
+    return attn_img, fig
+
+
+model_type = gr.Dropdown(
+    ["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
+input_img = gr.Image(shape=(250, 250))
+input_button = gr.Radio(["min", "max", "mean"], value="min", label="Attention Head Fusion",
+                        info="How to combine the last-layer attention across all attention heads of the transformer.")
+output_img = gr.Image(shape=(250, 250))
+output_plot = gr.Plot()
+
+markdown = "This is a demo for the Visual Cortex models. When passed an input image, it displays the attention of the last layer of the transformer.\n \
+    The user can decide how the attention heads are combined. \
+    Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
+demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
+                    examples=[[os.path.join('./imgs', x), None, None] for x in os.listdir(os.path.join(os.getcwd(), 'imgs')) if 'jpg' in x],
+                    inputs=[input_img, model_type, input_button], outputs=[output_img, output_plot])
+demo.launch()
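
The "16x48 or 16x64 grid" mentioned in the description follows from the embedding sizes of the two models (768 for the ViT-B based vc1-base, 1024 for the ViT-L based vc1-large) split into the 16 rows used by reshape(16, -1) in run_attn. A purely illustrative check, not part of the Space:

import numpy as np

for name, embedding_dim in [("vc1-base (ViT-B)", 768), ("vc1-large (ViT-L)", 1024)]:
    # Same reshape as ax.matshow(y.detach().numpy().reshape(16, -1)) in run_attn.
    grid = np.zeros(embedding_dim).reshape(16, -1)
    print(name, "->", grid.shape)  # (16, 48) and (16, 64)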
attn_helper.py ADDED
@@ -0,0 +1,113 @@
+import cv2
+from PIL import Image
+import numpy as np
+import torch
+
+
+def overlay_attn(original_image, mask):
+    """Overlay an attention mask on the original image and return a PIL image."""
+    # Colormap for the attention mask (cv2.COLORMAP_OCEAN is another option).
+    colormap_attn = cv2.COLORMAP_JET
+
+    # Resize the mask to the original image size and normalize it to [0, 1].
+    w, h = original_image.shape[0], original_image.shape[1]
+    mask = cv2.resize(mask / mask.max(), (h, w))[..., np.newaxis]
+
+    # Apply the colormap to the mask.
+    cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
+
+    # Blend the colored mask with the original image.
+    alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.1, cmap, 0.9, 0)
+
+    return Image.fromarray(alpha_blended)
+
+
+class VITAttentionGradRollout:
+    '''
+    Expects a timm ViT transformer model.
+    Adapted from https://github.com/samiraabnar/attention_flow
+    '''
+    def __init__(self, model, head_fusion='min', discard_ratio=0):
+        self.model = model
+        self.head_fusion = head_fusion
+        self.discard_ratio = discard_ratio
+
+        # Capture the attention map of every transformer block via forward hooks.
+        self.attentions = {}
+        for idx, module in enumerate(list(model.blocks.children())):
+            module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))
+
+    def get_attention(self, name):
+        def hook(module, input, output):
+            with torch.no_grad():
+                input = input[0]
+                B, N, C = input.shape
+                qkv = (
+                    module.qkv(input)
+                    .detach()
+                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
+                    .permute(2, 0, 3, 1, 4)
+                )
+                q, k, _ = (
+                    qkv[0],
+                    qkv[1],
+                    qkv[2],
+                )  # make torchscript happy (cannot use tensor as tuple)
+                # Recompute the attention weights from q and k.
+                attn = (q @ k.transpose(-2, -1)) * module.scale
+                attn = attn.softmax(dim=-1)
+                self.attentions[name] = attn
+        return hook
+
+    def get_attn_mask(self):
+        # Start the rollout from the identity matrix.
+        result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)
+
+        with torch.no_grad():
+            # Only roll out the final attention layer(s), starting at block 11.
+            for idx in range(11, len(self.attentions.keys())):
+                attention = self.attentions[f'attn{idx}']
+                if self.head_fusion == "mean":
+                    attention_heads_fused = attention.mean(axis=1)
+                elif self.head_fusion == "max":
+                    attention_heads_fused = attention.max(axis=1)[0]
+                elif self.head_fusion == "min":
+                    attention_heads_fused = attention.min(axis=1)[0]
+                else:
+                    raise ValueError("Attention head fusion type not supported")
+
+                # Drop the lowest attentions, but don't drop the class token.
+                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
+                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
+                indices = indices[indices != 0]
+                flat[0, indices] = 0
+
+                # Add the residual connection, renormalize, and accumulate the rollout.
+                I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
+                a = (attention_heads_fused + 1.0 * I) / 2
+                a = a / a.sum(dim=-1).unsqueeze(-1)
+                result = torch.matmul(a, result)
+
+        # Look at the total attention between the class token and the image patches.
+        mask = result[0, 0, 1:]
+        # For a 224x224 image this brings us from 196 patch tokens to a 14x14 grid.
+        width = int(mask.size(-1) ** 0.5)
+        mask = mask.reshape(width, width).detach().cpu().numpy()
+        mask = mask / np.max(mask)
+        return mask
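
Because the rollout helper only relies on a ViT that exposes blocks whose attn modules have qkv, num_heads, and scale attributes (as timm's VisionTransformer does, per the class docstring), it can also be exercised outside the Space. A minimal sketch, assuming timm is installed (it is not listed in requirements.txt) and using a randomly initialized ViT-B; the output filename is illustrative:

import timm
import torch
from attn_helper import VITAttentionGradRollout, overlay_attn

# Randomly initialized ViT-B/16; any timm ViT with 12 blocks works the same way.
vit = timm.create_model("vit_base_patch16_224", pretrained=False).eval()
rollout = VITAttentionGradRollout(vit, head_fusion="mean")

x = torch.rand(1, 3, 224, 224)            # dummy RGB batch
with torch.no_grad():
    vit(x)                                 # forward pass fills the attention hooks

mask = rollout.get_attn_mask()             # 14x14 patch-level attention map for a 224x224 input
overlay = overlay_attn(x[0].permute(1, 2, 0).numpy() * 255, mask)
overlay.save("rollout_overlay.png")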
imgs/adroit1.jpg ADDED
imgs/cheetah.jpg ADDED
imgs/ego4d.jpg ADDED
imgs/ego4d_2.jpg ADDED
imgs/ego4d_3.jpg ADDED
imgs/kitchen.jpg ADDED
imgs/reacher.jpg ADDED
imgs/rearrange.jpg ADDED
imgs/trifinger1.jpg ADDED
imgs/walker.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
+omegaconf
+pillow
+opencv-python
+torch
+numpy
+hydra-core
+gradio
+huggingface_hub
+matplotlib
+git+https://github.com/facebookresearch/eai-vc.git@main#subdirectory=vc_models