# Web file-viewer residue captured with the source (not part of the program):
#   harsh001's picture
#   Duplicate from harsh001/Explainability
#   4349468
#   raw
#   history blame
#   4.16 kB
import torch
import clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
def _min_max_normalize(values):
    """Linearly rescale a NumPy array to the [0, 1] range.

    A constant array (max == min) is mapped to all zeros instead of dividing
    by zero and filling the result with NaNs.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span


def _show_cam_on_image(img, mask):
    """Overlay a JET heatmap of ``mask`` on ``img`` and return the blend.

    Args:
        img: RGB image array of shape (H, W, 3) with values in [0, 1].
        mask: Array of shape (H, W) with values in [0, 1].

    Returns:
        A float32 RGB array of shape (H, W, 3) scaled so its maximum is 1.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # OpenCV produces BGR; convert the heatmap alone so it matches the RGB
    # image it is added to. (Swapping channels after blending, as was done
    # before, displayed the photo with red and blue exchanged.)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    return cam / np.max(cam)


def interpret(image, text, model, device, index=None):
    """Visualise which image regions drive the image/text score of ``model``.

    Runs ``model(image, text)``, back-propagates the logit selected by
    ``index``, combines the gradient-weighted attention maps of every block in
    ``model.visual.transformer.resblocks`` into a relevance map for the first
    token, upsamples that map to the input resolution and shows it as a
    heatmap over ``image[0]`` with matplotlib. Finally prints the softmax of
    the image logits.

    Args:
        image: Batched image tensor; only ``image[0]`` is visualised and it is
            treated as channels-first (it is permuted with ``(1, 2, 0)``).
        text: Tokenised text batch passed straight to ``model``.
        model: Callable returning ``(logits_per_image, logits_per_text)``.
            Each child of ``model.visual.transformer.resblocks`` must expose
            ``attn_probs`` and ``attn_grad`` after the forward/backward pass.
        device: Device on which the one-hot selector and the relevance matrix
            are placed; should be the device the model runs on.
        index: Column of ``logits_per_image`` to explain. ``None`` selects the
            arg-max column.

    Returns:
        None. Side effects: ``model.zero_grad()``, a backward pass, a
        matplotlib window and a line printed to stdout.
    """
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()

    if index is None:
        index = np.argmax(logits_per_image.detach().cpu().numpy(), axis=-1)

    # Select the target logit with a one-hot mask and back-propagate it so the
    # attention blocks can record their gradients.
    one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    # Use the caller-supplied device rather than a hard-coded .cuda(), which
    # raised on CPU-only machines.
    score = torch.sum(one_hot.to(device) * logits_per_image)
    model.zero_grad()
    score.backward(retain_graph=True)

    image_attn_blocks = list(model.visual.transformer.resblocks.children())
    first_probs = image_attn_blocks[0].attn_probs
    num_tokens = first_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=first_probs.dtype).to(device)
    for blk in image_attn_blocks:
        grad = blk.attn_grad
        cam = blk.attn_probs
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        # Keep only positive gradient-weighted attention, averaged over the
        # leading (flattened) dimension.
        cam = (grad * cam).clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    R[0, 0] = 0
    # Row 0 is the first token; drop its own column to keep the other tokens.
    image_relevance = R[0, 1:]

    # Derive the patch grid and output size from the data instead of assuming
    # a 7x7 grid and 224x224 input. A non-square token count still fails in
    # reshape, as it did with the fixed 7x7 grid.
    grid_side = int(round(image_relevance.shape[-1] ** 0.5))
    height, width = image.shape[-2:]
    image_relevance = image_relevance.reshape(1, 1, grid_side, grid_side)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=(height, width), mode='bilinear'
    )
    image_relevance = image_relevance.reshape(height, width).detach().cpu().numpy()
    image_relevance = _min_max_normalize(image_relevance)

    picture = image[0].permute(1, 2, 0).detach().cpu().numpy()
    picture = _min_max_normalize(picture)

    vis = _show_cam_on_image(picture, image_relevance)
    vis = np.uint8(255 * vis)
    plt.imshow(vis)
    plt.show()

    print("Label probs:", probs)
def main():
    """Run the relevance visualisation on the bundled example images.

    Loads the "ViT-B/32" CLIP model on CUDA when available (CPU otherwise) and,
    for every example image, calls ``interpret`` once per candidate caption so
    that a heatmap is shown for each caption index in order.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # (image file, candidate captions) pairs, in the order they are displayed.
    examples = (
        ("catdog.png", ["a dog", "a cat"]),
        ("el1.png", ["an elephant", "a zebra"]),
        ("el2.png", ["an elephant", "a zebra"]),
        ("el3.png", ["an elephant", "a zebra"]),
        ("el4.png", ["an elephant", "a zebra"]),
        ("dogbird.png", ["a basset hound", "a parrot"]),
    )

    for path, captions in examples:
        # Close the file handle as soon as the image has been preprocessed.
        with Image.open(path) as picture:
            image = preprocess(picture).unsqueeze(0).to(device)
        text = clip.tokenize(captions).to(device)
        for index, _ in enumerate(captions):
            interpret(model=model, image=image, text=text, device=device, index=index)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()