import os

import gradio as gr
import omegaconf
import torch
import torchvision
import vc_models
from huggingface_hub import hf_hub_download
from hydra import utils

from attn_helper import VITAttentionGradRollout, overlay_attn


HF_TOKEN = os.environ['HF_ACC_TOKEN']

# Checkpoints are cached in a model_ckpts/ directory alongside the
# installed vc_models package.
eai_filepath = vc_models.__file__.split('src')[0]
MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
os.makedirs(MODEL_DIR, exist_ok=True)

FILENAME = 'config.yaml'

# Lazily loaded (model, embedding_dim, transform, metadata) tuples, cached
# so switching model sizes in the UI doesn't re-instantiate the networks.
BASE_MODEL_TUPLE = None
LARGE_MODEL_TUPLE = None


def get_model(model_name):
    """Return the cached (model, embedding_dim, transform, metadata) tuple."""
    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
    download_bin(model_name)
    model = None
    if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
        repo_name = 'facebook/' + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
        BASE_MODEL_TUPLE[0].eval()
        model = BASE_MODEL_TUPLE
    elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
        repo_name = 'facebook/' + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
        LARGE_MODEL_TUPLE[0].eval()
        model = LARGE_MODEL_TUPLE
    elif model_name == 'vc1-base':
        model = BASE_MODEL_TUPLE
    elif model_name == 'vc1-large':
        model = LARGE_MODEL_TUPLE

    return model
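
# Usage sketch (assumes the facebook/vc1-base checkpoint is readable with
# HF_TOKEN; the first call downloads it, later calls hit the cache):
#   model, embedding_dim, transform, metadata = get_model('vc1-base')
#   emb = model(transform(torch.zeros(1, 3, 250, 250)))  # -> (1, embedding_dim)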


def download_bin(model):
    if model == 'vc1-large':
        bin_file = 'vc1_vitl.pth'
    elif model == 'vc1-base':
        bin_file = 'vc1_vitb.pth'
    else:
        raise ValueError('model not found: ' + model)

    repo_name = 'facebook/' + model
    bin_path = os.path.join(MODEL_DIR, bin_file)
    if not os.path.isfile(bin_path):
        model_bin = hf_hub_download(
            repo_id=repo_name,
            filename='pytorch_model.bin',
            local_dir=MODEL_DIR,
            local_dir_use_symlinks=True,
            token=HF_TOKEN,
        )
        # Rename the downloaded file to the name the model config expects.
        os.rename(model_bin, bin_path)
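
# Note: local_dir_use_symlinks is deprecated in newer huggingface_hub
# releases (local_dir downloads are plain copies there); it is kept here
# assuming the pinned client still supports it.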


def run_attn(input_img, model_name='vc1-large', discard_ratio=0.89):
    model, embedding_dim, transform, metadata = get_model(model_name)

    # Gradio provides an HWC uint8 array; convert to a float CHW batch.
    if input_img.shape[0] != 3:
        input_img = input_img.transpose(2, 0, 1)
    if len(input_img.shape) == 3:
        input_img = torch.tensor(input_img).unsqueeze(0)
    input_img = input_img.float()
    resize_transform = torchvision.transforms.Resize((250, 250))
    input_img = resize_transform(input_img)
    x = transform(input_img)

    attention_rollout = VITAttentionGradRollout(
        model, head_fusion='max', discard_ratio=discard_ratio
    )

    # The forward pass populates the rollout helper's attention records;
    # the model output itself is not needed here.
    model(x)
    mask = attention_rollout.get_attn_mask()
    attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
    return attn_img
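
# Standalone usage sketch ('imgs/example.jpg' is a hypothetical path; any
# RGB jpg works):
#   from PIL import Image
#   import numpy as np
#   img = np.asarray(Image.open('imgs/example.jpg'))  # HWC uint8, as Gradio sends
#   heatmap = run_attn(img, 'vc1-base', discard_ratio=0.9)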


# Gradio UI components.
model_type = gr.Dropdown(
    ['vc1-base', 'vc1-large'], label='Model Size', value='vc1-large')
input_img = gr.Image(shape=(250, 250))
output_img = gr.Image(shape=(250, 250))
discard_ratio = gr.Slider(0, 1, value=0.89, label='Discard Ratio')
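
# discard_ratio follows the usual attention-rollout recipe: that fraction of
# the lowest-weight attention links is pruned before layers are fused (the
# exact behavior lives in attn_helper's VITAttentionGradRollout).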

css = '#component-2, .input-image, .image-preview {height: 240px !important}'
markdown = ('This is a demo for the Visual Cortex (VC-1) models. Given an '
            'input image, it displays the attention of the last transformer '
            'layer overlaid on the image.')

demo = gr.Interface(
    fn=run_attn,
    title='Visual Cortex Large Model',
    description=markdown,
    examples=[[os.path.join('./imgs', f), None, None]
              for f in os.listdir(os.path.join(os.getcwd(), 'imgs'))
              if 'jpg' in f],
    inputs=[input_img, model_type, discard_ratio],
    outputs=output_img,
    css=css,
)
demo.launch()