import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTModel
from skops import hub_utils
from einops import reduce
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os

# Imagenette class labels
labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute'
]

# load DINO
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)

# load logistic regression (skip the download if the repo is already present,
# so reruns don't crash on an existing directory)
if not os.path.exists('emb-gam-dino'):
    hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')

with open('emb-gam-dino/model.pkl', 'rb') as file:
    logistic_regression = pickle.load(file)


def classify_and_heatmap(input_img):
    # get patch embeddings (drop the CLS token, keep the 14x14 patch tokens)
    inputs = {k: v.to(device) for k, v in feature_extractor(input_img, return_tensors='pt').items()}
    with torch.no_grad():
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

    # get class scores from the sum of all patch embeddings
    scores = dict(zip(
        labels,
        logistic_regression.predict_proba(reduce(patch_embeddings, 'p d -> () d', 'sum'))[0]
    ))

    # make plot
    num_patches_side = model.config.image_size // model.config.patch_size

    # set up figure: the first column (spanning both rows) holds the input
    # image, the remaining 2x5 axes hold one heatmap per class
    fig, axs = plt.subplots(2, 6, figsize=(12, 5))
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])

    # plot original image (un-normalize the pixel values first; move them back
    # to CPU so the arithmetic with the CPU mean/std tensors doesn't error on GPU)
    img = to_pil_image(
        inputs['pixel_values'].squeeze(0).cpu()
        * torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
        + torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
    )
    ax_orig_img.imshow(img)
    ax_orig_img.axis('off')

    # per-patch contribution to each class logit, with the intercept spread
    # evenly across patches so each heatmap sums to its class score
    patch_contributions = (
        logistic_regression.coef_
        @ patch_embeddings.T.numpy()
        + logistic_regression.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    ).reshape(-1, num_patches_side, num_patches_side)
    vmin = patch_contributions.min()
    vmax = patch_contributions.max()

    for i, ax in enumerate(axs[:, 1:].flat):
        sns.heatmap(
            patch_contributions[i],
            ax=ax,
            square=True,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(labels[i])
        ax.set_xlabel(f'score={patch_contributions[i].sum():.2f}')
        ax.set_xticks([])
        ax.set_yticks([])

    return scores, fig
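
# Optional sanity check (not part of the original demo; `check_additivity` is a
# hypothetical helper added here as a sketch). Because the classifier is linear,
# the logit computed from the summed patch embeddings decomposes exactly into
# the per-patch contributions visualized above -- this additivity is what makes
# the heatmaps faithful to the prediction.
def check_additivity(input_img, atol=1e-3):
    inputs = {k: v.to(device) for k, v in feature_extractor(input_img, return_tensors='pt').items()}
    with torch.no_grad():
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu().numpy()
    # logit of the full image: w . (sum of patch embeddings) + b
    total_logits = logistic_regression.decision_function(patch_embeddings.sum(0, keepdims=True))[0]
    # the same quantity recomposed from per-patch terms: sum over patches of (w . patch) + b
    per_patch_logits = logistic_regression.coef_ @ patch_embeddings.T
    recomposed = per_patch_logits.sum(-1) + logistic_regression.intercept_
    assert abs(total_logits - recomposed).max() < atol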
description = '''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It performs image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contributions of each image patch to each label.
'''

article = '''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).

Citations for the works this demo builds on (not our papers):

```bibtex
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}

@InProceedings{Caron_2021_ICCV,
  author    = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  title     = {Emerging Properties in Self-Supervised Vision Transformers},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month     = {October},
  year      = {2021},
  pages     = {9650-9660}
}

@misc{imagenette,
  author = {fast.ai},
  title  = {Imagenette},
  url    = {https://github.com/fastai/imagenette}
}
```
'''

demo = gr.Interface(
    fn=classify_and_heatmap,
    inputs=gr.Image(shape=(224, 224), type='pil', label='Input Image'),
    outputs=[
        gr.Label(label='Class'),
        gr.Plot(label='Patch Contributions')
    ],
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=['./examples/english_springer.png', './examples/golf_ball.png']
)

demo.launch(debug=True)
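
# Hypothetical local usage sketch (assumes this file is saved as app.py and the
# example images exist under ./examples; package names match the imports above):
#
#   pip install gradio torch transformers skops einops torchvision matplotlib seaborn
#   python app.py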