Tingfan silwals committed on
Commit 1487b68 · 0 Parent(s)

Duplicate from facebook/vc1-large

Co-authored-by: sneha <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Visual Cortex Demo
+ emoji: 🏢
+ colorFrom: red
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.23.0
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-4.0
+ duplicated_from: facebook/vc1-large
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,96 @@
+ import numpy as np
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ import omegaconf
+ from hydra import utils
+ import os
+ import torch
+ import matplotlib.pyplot as plt
+ from attn_helper import VITAttentionGradRollout, overlay_attn
+ import vc_models
+ import torchvision
+
+
+ HF_TOKEN = os.environ['HF_ACC_TOKEN']
+ eai_filepath = vc_models.__file__.split('src')[0]
+ MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
+ if not os.path.isdir(MODEL_DIR):
+     os.mkdir(MODEL_DIR)
+
+ FILENAME = "config.yaml"
+ BASE_MODEL_TUPLE = None
+ LARGE_MODEL_TUPLE = None
+ def get_model(model_name):
+     global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
+     download_bin(model_name)
+     model = None
+     if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
+         repo_name = "facebook/" + model_name
+         model_cfg = omegaconf.OmegaConf.load(
+             hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
+         )
+         BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
+         BASE_MODEL_TUPLE[0].eval()
+         model = BASE_MODEL_TUPLE
+     elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
+         repo_name = "facebook/" + model_name
+         model_cfg = omegaconf.OmegaConf.load(
+             hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
+         )
+         LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
+         LARGE_MODEL_TUPLE[0].eval()
+         model = LARGE_MODEL_TUPLE
+     elif model_name == 'vc1-base':
+         model = BASE_MODEL_TUPLE
+     elif model_name == 'vc1-large':
+         model = LARGE_MODEL_TUPLE
+
+     return model
+
+ def download_bin(model):
+     bin_file = ""
+     if model == "vc1-large":
+         bin_file = 'vc1_vitl.pth'
+     elif model == "vc1-base":
+         bin_file = 'vc1_vitb.pth'
+     else:
+         raise NameError("model not found: " + model)
+
+     repo_name = 'facebook/' + model
+     bin_path = os.path.join(MODEL_DIR, bin_file)
+     if not os.path.isfile(bin_path):
+         model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin', local_dir=MODEL_DIR, local_dir_use_symlinks=True, token=HF_TOKEN)
+         os.rename(model_bin, bin_path)
+
+
+ def run_attn(input_img, model="vc1-large", discard_ratio=0.89):
+     download_bin(model)
+     model, embedding_dim, transform, metadata = get_model(model)
+     if input_img.shape[0] != 3:
+         input_img = input_img.transpose(2, 0, 1)
+     if len(input_img.shape) == 3:
+         input_img = torch.tensor(input_img).unsqueeze(0)
+     input_img = input_img.float()
+     resize_transform = torchvision.transforms.Resize((250, 250))
+     input_img = resize_transform(input_img)
+     x = transform(input_img)
+
+     attention_rollout = VITAttentionGradRollout(model, head_fusion="max", discard_ratio=discard_ratio)
+
+     y = model(x)
+     mask = attention_rollout.get_attn_mask()
+     attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
+     return attn_img
+
+ model_type = gr.Dropdown(
+     ["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
+ input_img = gr.Image(shape=(250, 250))
+ output_img = gr.Image(shape=(250, 250))
+ discard_ratio = gr.Slider(0, 1, value=0.89)
+ css = "#component-2, .input-image, .image-preview {height: 240px !important}"
+ markdown = "This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer."
+
+ demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
+                     examples=[[os.path.join('./imgs', x), None, None] for x in os.listdir(os.path.join(os.getcwd(), 'imgs')) if 'jpg' in x],
+                     inputs=[input_img, model_type, discard_ratio], outputs=output_img, css=css)
+ demo.launch()
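For reference, `run_attn` can also be exercised outside the Gradio UI. A minimal sketch, assuming `HF_ACC_TOKEN` is exported, one of the bundled example images (e.g. `imgs/kitchen.jpg`) is on disk, and the `demo.launch()` call above is temporarily guarded (e.g. behind `if __name__ == "__main__":`) so that importing `app` does not start the server:

```python
# Sketch: run the attention overlay on one of the bundled example images.
import numpy as np
from PIL import Image

from app import run_attn  # assumes demo.launch() is guarded against import

img = np.array(Image.open("imgs/kitchen.jpg").convert("RGB"))  # H x W x 3 uint8
overlay = run_attn(img, model="vc1-base", discard_ratio=0.89)  # returns a PIL image
overlay.save("kitchen_attention.png")
```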
attn_helper.py ADDED
@@ -0,0 +1,111 @@
+ import cv2
+ from PIL import Image
+ import numpy as np
+ import torch
+
+ import PIL
+
+ def overlay_attn(original_image, mask):
+     # Colormap and alpha for attention mask
+     # COLORMAP_OCEAN
+     # COLORMAP_OCEAN
+     colormap_attn, alpha_attn = cv2.COLORMAP_VIRIDIS, 1  # 0.85
+
+     # Resize mask to original image size
+     w, h = original_image.shape[0], original_image.shape[1]
+     mask = cv2.resize(mask / mask.max(), (h, w))[..., np.newaxis]
+
+     # Apply colormap to mask
+     cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
+
+     # Blend mask and original image
+     # grayscale_img = cv2.cvtColor(np.uint8(original_image), cv2.COLOR_RGB2GRAY)
+     # grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)
+     # alpha_blended = cv2.addWeighted(np.uint8(original_image), 1, cmap, alpha_attn, 0)
+     alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.4, cmap, 0.6, 0)
+
+
+     # alpha_blended = cmap
+
+
+     # Return blended image
+     final_im = Image.fromarray(alpha_blended)
+     # final_im = final_im.crop((0, 0, 250, 250))
+     return final_im
+
+
+
+ class VITAttentionGradRollout:
+     '''
+     Expects a timm ViT transformer model.
+     Adapted from https://github.com/samiraabnar/attention_flow
+     '''
+     def __init__(self, model, head_fusion='min', discard_ratio=0):
+         self.model = model
+         self.head_fusion = head_fusion
+         self.discard_ratio = discard_ratio
+
+         self.attentions = {}
+         for idx, module in enumerate(list(model.blocks.children())):
+             module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))
+
+
+     def get_attention(self, name):
+         def hook(module, input, output):
+             with torch.no_grad():
+                 input = input[0]
+                 B, N, C = input.shape
+                 qkv = (
+                     module.qkv(input)
+                     .detach()
+                     .reshape(B, N, 3, module.num_heads, C // module.num_heads)
+                     .permute(2, 0, 3, 1, 4)
+                 )
+                 q, k, _ = (
+                     qkv[0],
+                     qkv[1],
+                     qkv[2],
+                 )  # make torchscript happy (cannot use tensor as tuple)
+                 attn = (q @ k.transpose(-2, -1)) * module.scale
+                 attn = attn.softmax(dim=-1)
+                 self.attentions[name] = attn
+         return hook
+
+     def get_attn_mask(self, k=0):
+         attn_key = "attn" + str(k)  # currently unused; the loop below hard-codes the starting layer
+         result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)
+
+         # result = torch.eye(self.attentions['attn2'].size(-1)).to(self.attentions['attn2'].device)
+         with torch.no_grad():
+             # for attention in self.attentions.values():
+             for k in range(11, len(self.attentions.keys())):
+                 attention = self.attentions[f'attn{k}']
+                 if self.head_fusion == "mean":
+                     attention_heads_fused = attention.mean(axis=1)
+                 elif self.head_fusion == "max":
+                     attention_heads_fused = attention.max(axis=1)[0]
+                 elif self.head_fusion == "min":
+                     attention_heads_fused = attention.min(axis=1)[0]
+                 else:
+                     raise ValueError("Attention head fusion type not supported")
+
+                 # Drop the lowest attentions, but
+                 # don't drop the class token
+                 flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
+                 _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
+                 indices = indices[indices != 0]
+                 flat[0, indices] = 0
+                 I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
+                 a = (attention_heads_fused + 1.0 * I) / 2
+                 a = a / a.sum(dim=-1).unsqueeze(-1)
+
+                 result = torch.matmul(a, result)
+
+         # Look at the total attention between the class token
+         # and the image patches
+         mask = result[0, 0, 1:]
+         # In case of a 224x224 image, this brings us from 196 patches to a 14x14 grid
+         width = int(mask.size(-1)**0.5)
+         mask = mask.reshape(width, width).detach().cpu().numpy()
+         mask = mask / np.max(mask)
+         return mask
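`VITAttentionGradRollout` only relies on the timm-style block layout (`blocks[i].attn` with `qkv`, `num_heads`, and `scale`), so it can be sanity-checked on a stock timm ViT without downloading the VC-1 checkpoints. A minimal sketch, assuming the `timm` package is installed (it is not part of this Space's `requirements.txt`) and using an illustrative model name:

```python
# Sketch: attention rollout on a stock timm ViT with a random input batch.
import torch
import timm

from attn_helper import VITAttentionGradRollout, overlay_attn

vit = timm.create_model("vit_base_patch16_224", pretrained=False).eval()
rollout = VITAttentionGradRollout(vit, head_fusion="max", discard_ratio=0.9)

img = torch.rand(1, 3, 224, 224)   # stand-in for a preprocessed image batch
with torch.no_grad():
    vit(img)                       # forward pass fills rollout.attentions via the hooks
mask = rollout.get_attn_mask()     # 14 x 14 patch-level attention map

overlay = overlay_attn(img[0].permute(1, 2, 0).numpy() * 255, mask)
overlay.save("rollout_check.png")
```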
imgs/adroit1.jpg ADDED
imgs/cheetah.jpg ADDED
imgs/ego4d.jpg ADDED
imgs/ego4d_2.jpg ADDED
imgs/ego4d_3.jpg ADDED
imgs/kitchen.jpg ADDED
imgs/reacher.jpg ADDED
imgs/rearrange.jpg ADDED
imgs/trifinger1.jpg ADDED
imgs/walker.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ omegaconf
+ pillow
+ opencv-python
+ torch
+ numpy
+ hydra-core
+ gradio
+ huggingface_hub
+ matplotlib
+ git+https://github.com/facebookresearch/eai-vc.git@main#subdirectory=vc_models