File size: 3,666 Bytes
cf37148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30f2229
cf37148
 
30f2229
cf37148
 
 
4ef25ba
cf37148
 
 
 
 
 
 
 
 
 
 
4ef25ba
cf37148
 
 
 
ffecbe9
cf37148
 
 
 
 
4ef25ba
ffecbe9
4ef25ba
 
cf37148
4ef25ba
 
cf37148
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import omegaconf
from hydra import utils
import os
import torch
import matplotlib.pyplot as plt
from attn_helper import VITAttentionGradRollout, overlay_attn
import vc_models
import torchvision


# Hugging Face Hub token used for every download below. It is read eagerly, so a
# missing HF_ACC_TOKEN environment variable fails at import time with KeyError.
HF_TOKEN = os.environ['HF_ACC_TOKEN']
# Checkpoints live next to the vc_models source tree: everything before the first
# 'src' in the package path, e.g. <repo>/src/vc_models/__init__.py -> <repo>/model_ckpts.
# NOTE(review): splitting on the bare substring 'src' misfires if an earlier path
# component contains it (e.g. /usr/src/...); consider anchoring on os.sep + 'src' + os.sep.
eai_filepath = vc_models.__file__.split('src')[0]
MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
# makedirs(exist_ok=True) also creates missing parents and does not race with a
# concurrent creator, unlike a separate isdir() check followed by mkdir().
os.makedirs(MODEL_DIR, exist_ok=True)

# Name of the Hydra config file fetched from each facebook/<model> Hub repo.
FILENAME = "config.yaml"
# Lazily populated caches of the instantiated model tuples, filled on first use by get_model.
BASE_MODEL_TUPLE = None
LARGE_MODEL_TUPLE = None
def _load_model_tuple(model_name):
    """Fetch facebook/<model_name>'s Hydra config, instantiate it and put the network in eval mode."""
    repo_name = "facebook/" + model_name
    model_cfg = omegaconf.OmegaConf.load(
        hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
    )
    model_tuple = utils.instantiate(model_cfg)
    # Element 0 is the network itself (callers unpack it as the model); eval()
    # switches it to inference mode.
    model_tuple[0].eval()
    return model_tuple


def get_model(model_name):
    """Return the cached model tuple for ``model_name``, building it on first use.

    The checkpoint is ensured on disk via ``download_bin`` before anything else,
    so an unknown ``model_name`` raises ``NameError`` there.

    Args:
        model_name: ``"vc1-base"`` or ``"vc1-large"``.

    Returns:
        The tuple produced by instantiating the model's Hub config (callers unpack
        it as ``model, embedding_dim, transform, metadata``). The result is cached
        in ``BASE_MODEL_TUPLE`` / ``LARGE_MODEL_TUPLE`` so later calls reuse it.
        Returns ``None`` for any other name that gets past ``download_bin``.
    """
    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
    download_bin(model_name)
    if model_name == 'vc1-base':
        if BASE_MODEL_TUPLE is None:
            BASE_MODEL_TUPLE = _load_model_tuple(model_name)
        return BASE_MODEL_TUPLE
    if model_name == 'vc1-large':
        if LARGE_MODEL_TUPLE is None:
            LARGE_MODEL_TUPLE = _load_model_tuple(model_name)
        return LARGE_MODEL_TUPLE
    return None

def download_bin(model):
    """Make sure the raw checkpoint for ``model`` exists under ``MODEL_DIR``.

    If the renamed checkpoint is not already present, ``pytorch_model.bin`` is
    downloaded from the ``facebook/<model>`` Hub repo into ``MODEL_DIR``. It is
    then renamed to the model-specific file name (``vc1_vitb.pth`` or
    ``vc1_vitl.pth``).

    Raises:
        NameError: if ``model`` is neither ``"vc1-large"`` nor ``"vc1-base"``.
    """
    checkpoint_names = {
        "vc1-large": 'vc1_vitl.pth',
        "vc1-base": 'vc1_vitb.pth',
    }
    if model not in checkpoint_names:
        raise NameError("model not found: " + model)

    target_path = os.path.join(MODEL_DIR, checkpoint_names[model])
    if os.path.isfile(target_path):
        # Already downloaded and renamed on a previous call.
        return

    downloaded_path = hf_hub_download(
        repo_id='facebook/' + model,
        filename='pytorch_model.bin',
        local_dir=MODEL_DIR,
        local_dir_use_symlinks=True,
        token=HF_TOKEN,
    )
    os.rename(downloaded_path, target_path)


def run_attn(input_img, model="vc1-large", discard_ratio=0.89):
    """Compute an attention-rollout overlay for ``input_img`` with a VC-1 model.

    Args:
        input_img: image array from the Gradio ``Image`` input. Presumably an
            HxWx3 numpy array (a leading dim of 3 is treated as already
            channel-first) — TODO confirm for non-RGB uploads.
        model: ``"vc1-base"`` or ``"vc1-large"``; ``None`` falls back to ``"vc1-large"``.
        discard_ratio: forwarded to ``VITAttentionGradRollout``; ``None`` falls
            back to ``0.89``.

    Returns:
        Whatever ``overlay_attn`` produces for the resized image and the
        rollout mask. It is shown in the output Gradio ``Image``.
    """
    # The examples table can supply None for these inputs; use the signature
    # defaults instead of crashing inside download_bin / the rollout.
    model_name = "vc1-large" if model is None else model
    if discard_ratio is None:
        discard_ratio = 0.89

    # get_model() already calls download_bin(), so no separate download here.
    vit, embedding_dim, transform, metadata = get_model(model_name)

    # Heuristic: anything whose first dim is not 3 is assumed to be HxWxC.
    # NOTE(review): an HxWxC image with H == 3 would be misread as channel-first.
    if input_img.shape[0] != 3:
        input_img = input_img.transpose(2, 0, 1)
    if len(input_img.shape) == 3:
        input_img = torch.tensor(input_img).unsqueeze(0)  # add batch dim
    input_img = input_img.float()
    input_img = torchvision.transforms.Resize((250, 250))(input_img)
    x = transform(input_img)

    attention_rollout = VITAttentionGradRollout(
        vit, head_fusion="max", discard_ratio=discard_ratio)
    # The forward output is unused; presumably the rollout object captures the
    # attention it needs during this pass — confirm in attn_helper.
    vit(x)
    mask = attention_rollout.get_attn_mask()
    return overlay_attn(input_img[0].permute(1, 2, 0), mask)

model_type = gr.Dropdown(
    ["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
input_img = gr.Image(shape=(250, 250))
output_img = gr.Image(shape=(250, 250))
discard_ratio = gr.Slider(0, 1, value=0.89)
css = "#component-2, .input-image, .image-preview {height: 240px !important}"
markdown = "This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer."

# One example row per .jpg/.jpeg file in ./imgs. An extension check replaces the
# substring test 'jpg' in name, which also matched names like 'notes_jpg.txt'.
# Sorting keeps the gallery order stable (os.listdir order is arbitrary). Each row
# carries the same model / discard-ratio defaults as the widgets above instead of
# None, so clicking an example never feeds None into run_attn.
examples = [
    [os.path.join('./imgs', name), "vc1-large", 0.89]
    for name in sorted(os.listdir(os.path.join(os.getcwd(), 'imgs')))
    if name.lower().endswith(('.jpg', '.jpeg'))
]

demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
                    examples=examples,
                    inputs=[input_img, model_type, discard_ratio], outputs=output_img, css=css)
demo.launch()