# Gradio demo for the VC-1 (Visual Cortex) models: visualizes the last-layer
# attention of the ViT backbone along with the resulting embedding values.
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import omegaconf
from hydra import utils
import os
import torch
import matplotlib.pyplot as plt

from attn_helper import VITAttentionGradRollout, overlay_attn
import vc_models
# import eaif_models
import torchvision

HF_TOKEN = os.environ['HF_ACC_TOKEN']

# Checkpoints are stored next to the installed vc_models package.
eai_filepath = vc_models.__file__.split('src')[0]
MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
if not os.path.isdir(MODEL_DIR):
    os.mkdir(MODEL_DIR)

FILENAME = "config.yaml"

# Lazily loaded (model, embedding_dim, transform, metadata) tuples, cached per model size.
BASE_MODEL_TUPLE = None
LARGE_MODEL_TUPLE = None


def get_model(model_name):
    """Load and cache the requested VC-1 model from its Hub config."""
    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
    download_bin(model_name)
    model = None
    if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
        repo_name = "facebook/" + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
        BASE_MODEL_TUPLE[0].eval()
        model = BASE_MODEL_TUPLE
    elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
        repo_name = "facebook/" + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
        LARGE_MODEL_TUPLE[0].eval()
        model = LARGE_MODEL_TUPLE
    elif model_name == 'vc1-base':
        model = BASE_MODEL_TUPLE
    elif model_name == 'vc1-large':
        model = LARGE_MODEL_TUPLE
    return model


def download_bin(model):
    """Download the model checkpoint into MODEL_DIR if it is not already present."""
    if model == "vc1-large":
        bin_file = 'vc1_vitl.pth'
    elif model == "vc1-base":
        bin_file = 'vc1_vitb.pth'
    else:
        raise NameError("model not found: " + model)

    repo_name = 'facebook/' + model
    bin_path = os.path.join(MODEL_DIR, bin_file)
    if not os.path.isfile(bin_path):
        model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',
                                    local_dir=MODEL_DIR, local_dir_use_symlinks=True,
                                    token=HF_TOKEN)
        os.rename(model_bin, bin_path)


def run_attn(input_img, model="vc1-large", fusion="min", slider=0):
    """Return the attention-rollout overlay image and a plot of the embedding values."""
    download_bin(model)
    model, embedding_dim, transform, metadata = get_model(model)

    # Convert the incoming HWC numpy image to a float CHW tensor with a batch dimension.
    if input_img.shape[0] != 3:
        input_img = input_img.transpose(2, 0, 1)
    if len(input_img.shape) == 3:
        input_img = torch.tensor(input_img).unsqueeze(0)
    input_img = input_img.float()

    resize_transform = torchvision.transforms.Resize((250, 250))
    input_img = resize_transform(input_img)
    x = transform(input_img)

    # Hook the attention layers, run a forward pass, then build the rollout mask overlay.
    attention_rollout = VITAttentionGradRollout(model, head_fusion=fusion, discard_ratio=slider)
    y = model(x)
    mask = attention_rollout.get_attn_mask()
    attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)

    # Plot the embedding reshaped to a 16-row grid (16x48 for base, 16x64 for large).
    fig = plt.figure()
    ax = fig.subplots()
    print(y.shape)
    im = ax.matshow(y.detach().numpy().reshape(16, -1))
    plt.colorbar(im)

    return attn_img, fig


# Gradio UI components.
model_type = gr.Dropdown(["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
input_img = gr.Image(shape=(250, 250))
input_button = gr.Radio(["min", "max", "mean"], value="min", label="Attention Head Fusion",
                        info="How to combine the last layer attention across all 12 heads of the transformer.")
output_img = gr.Image(shape=(250, 250))
output_plot = gr.Plot()
css = "#component-3, .input-image, .image-preview {height: 240px !important}"
slider = gr.Slider(0, 1)
markdown = ("This is a demo for the Visual Cortex models. When passed an image input, it displays "
            "the attention of the last layer of the transformer.\n"
            "The user can decide how the attention heads will be combined. "
            "Along with the attention heatmap, it also displays the embedding values reshaped to a "
            "16x48 or 16x64 grid.")
demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
                    examples=[[os.path.join('./imgs', x), None, None, None]
                              for x in os.listdir(os.path.join(os.getcwd(), 'imgs')) if 'jpg' in x],
                    inputs=[input_img, model_type, input_button, slider],
                    outputs=[output_img, output_plot], css=css)
demo.launch()
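
# Optional, commented-out sketch (not part of the original demo) showing how run_attn
# can be exercised directly without the Gradio UI; it assumes HF_ACC_TOKEN is set and
# that the checkpoint download succeeds. The random array is only a placeholder image.
#
# dummy = (np.random.rand(250, 250, 3) * 255).astype(np.uint8)
# heatmap, embedding_fig = run_attn(dummy, model="vc1-base", fusion="mean", slider=0.5)
# embedding_fig.savefig("embedding_grid.png")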