Stefan Denner
Initial commit
208214b
import math
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, InterpolationMode
import open_clip
from open_clip.transformer import VisionTransformer
from open_clip.timm_model import TimmModel
from einops import rearrange
from .utils import (
hooked_attention_timm_forward,
hooked_resblock_forward,
hooked_attention_forward,
hooked_resblock_timm_forward,
hooked_attentional_pooler_timm_forward,
vit_dynamic_size_forward,
min_max,
hooked_torch_multi_head_attention_forward,
)
class LeWrapper(nn.Module):
    """
    Wrapper around OpenCLIP to add LeGrad to OpenCLIP's model while keeping all
    the functionalities of the original model.

    On construction the vision tower is inspected and its blocks / pooling head
    get their ``forward`` replaced by the hooked versions from ``.utils``. The
    ``compute_legrad*`` methods then read the attributes those hooks are
    expected to set (``feat_post_mlp``, ``attention_map``, ``attention_maps``,
    ``attn_probs``). They take gradients of the image-text similarity w.r.t. the
    recorded attention maps to build a heatmap.

    ``model_type`` is set to one of:
      * ``"clip"``        -- open_clip ``VisionTransformer`` without ``attn_pool``
      * ``"coca"``        -- open_clip ``VisionTransformer`` with ``attn_pool``
      * ``"timm_siglip"`` -- ``TimmModel`` whose trunk has an ``attn_pool``
      * ``"timm_clip"``   -- ``TimmModel`` whose trunk has no ``attn_pool``
    """

    def __init__(self, model, layer_index=-2):
        """
        Args:
            model: an OpenCLIP model; must expose ``visual`` plus the
                ``encode_image`` / ``encode_text`` methods used below.
            layer_index: first block (0-based, negative values count from the
                end) whose attention contributes to the explanation.
        """
        super(LeWrapper, self).__init__()
        # ------------ copy of model's attributes and methods ------------
        # Only dunder names are skipped, so bound methods, sub-modules and the
        # single-underscore nn.Module bookkeeping are all re-bound to the very
        # same objects held by ``model``.
        for attr in dir(model):
            if not attr.startswith("__"):
                setattr(self, attr, getattr(model, attr))
        # ------------ activate hooks & gradient ------------
        self._activate_hooks(layer_index=layer_index)

    def _activate_hooks(self, layer_index):
        """Detect the vision-tower type and install the matching hooks.

        Sets ``self.patch_size``, ``self.starting_depth`` and ``self.model_type``.

        Raises:
            ValueError: if ``self.visual`` is neither an open_clip
                ``VisionTransformer`` nor a ``TimmModel``.
        """
        # ------------ identify model's type ------------
        print("Activating necessary hooks and gradients ....")
        if isinstance(self.visual, VisionTransformer):
            # --- Activate dynamic image size ---
            self.visual.forward = types.MethodType(
                vit_dynamic_size_forward, self.visual
            )
            # Get patch size
            self.patch_size = self.visual.patch_size[0]
            # Get starting depth (in case of negative layer_index)
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.transformer.resblocks) + layer_index
            )
            if self.visual.attn_pool is None:
                self.model_type = "clip"
                self._activate_self_attention_hooks()
            else:
                self.model_type = "coca"
                self._activate_att_pool_hooks(layer_index=layer_index)
        elif isinstance(self.visual, TimmModel):
            trunk = self.visual.trunk
            # --- Activate dynamic image size ---
            trunk.dynamic_img_size = True
            trunk.patch_embed.dynamic_img_size = True
            trunk.patch_embed.strict_img_size = False
            trunk.patch_embed.flatten = False
            trunk.patch_embed.output_fmt = "NHWC"
            # --- Get patch size ---
            self.patch_size = trunk.patch_embed.patch_size[0]
            # --- Get starting depth (in case of negative layer_index) ---
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(trunk.blocks) + layer_index
            )
            if getattr(trunk, "attn_pool", None) is not None:
                self.model_type = "timm_siglip"
                self._activate_timm_attn_pool_hooks(layer_index=layer_index)
            else:
                # Without an attentional pooler the explanation must come from the
                # blocks' self-attention, which compute_legrad_clip handles.
                self.model_type = "timm_clip"
                self._activate_timm_self_attention_hooks()
        else:
            raise ValueError(
                "Model currently not supported, see legrad.list_pretrained() for a list of available models"
            )
        print("Hooks and gradients activated!")

    def _freeze_all_but_deep_blocks(self, blocks_prefix, extra_trainable_prefixes=()):
        """Disable gradients everywhere except blocks at depth >= ``starting_depth``.

        Some parameter upstream of the recorded attention maps must require grad.
        Otherwise ``torch.autograd.grad`` in the ``compute_legrad*`` methods has
        nothing to differentiate.

        Args:
            blocks_prefix: parameter-name prefix of the block container, e.g.
                ``"visual.trunk.blocks"``; the next dotted component is the depth.
            extra_trainable_prefixes: parameter-name prefixes that keep
                ``requires_grad=True`` regardless of depth.
        """
        block_prefix = blocks_prefix + "."
        extra_trainable_prefixes = tuple(extra_trainable_prefixes)
        for name, param in self.named_parameters():
            trainable = name.startswith(extra_trainable_prefixes)
            if name.startswith(block_prefix):
                depth = int(name[len(block_prefix):].split(".", 1)[0])
                trainable = trainable or depth >= self.starting_depth
            param.requires_grad = trainable

    def _activate_self_attention_hooks(self):
        """Hook the self-attention blocks from ``starting_depth`` onwards.

        Raises:
            ValueError: if the vision tower type is not supported.
        """
        # The parameter-name prefix must match the tower actually in use;
        # a mismatch would leave every parameter frozen.
        if isinstance(self.visual, VisionTransformer):
            blocks = self.visual.transformer.resblocks
            blocks_prefix = "visual.transformer.resblocks"
        elif isinstance(self.visual, TimmModel):
            blocks = self.visual.trunk.blocks
            blocks_prefix = "visual.trunk.blocks"
        else:
            raise ValueError("Unsupported model type for self-attention hooks")
        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        self._freeze_all_but_deep_blocks(blocks_prefix)
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_forward, blocks[layer]
            )

    def _activate_timm_self_attention_hooks(self):
        """Hook the timm trunk blocks (and their attention) from ``starting_depth`` onwards."""
        blocks = self.visual.trunk.blocks
        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        self._freeze_all_but_deep_blocks("visual.trunk.blocks")
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_timm_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, blocks[layer]
            )

    def _activate_att_pool_hooks(self, layer_index):
        """Hook open_clip resblocks from ``starting_depth`` and the attentional pooler.

        ``layer_index`` is unused (``self.starting_depth`` is already resolved);
        it is kept for interface compatibility.
        """
        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        self._freeze_all_but_deep_blocks("visual.transformer.resblocks")
        # --- Activate the hooks for the specific layers ---
        resblocks = self.visual.transformer.resblocks
        for layer in range(self.starting_depth, len(resblocks)):
            resblocks[layer].forward = types.MethodType(
                hooked_resblock_forward, resblocks[layer]
            )
        # --- Apply hook on the attentional pooler ---
        self.visual.attn_pool.attn.forward = types.MethodType(
            hooked_torch_multi_head_attention_forward, self.visual.attn_pool.attn
        )

    def _activate_timm_attn_pool_hooks(self, layer_index):
        """Hook a timm trunk with an attentional pooler (SigLIP-style towers).

        ``layer_index`` is unused (``self.starting_depth`` is already resolved);
        it is kept for interface compatibility.

        Raises:
            ValueError: if the trunk has no ``attn_pool``.
        """
        trunk = self.visual.trunk
        # Ensure all components are present before attaching hooks
        if getattr(trunk, "attn_pool", None) is None:
            raise ValueError("Attentional pooling not found in TimmModel")
        # NOTE(review): this patches timm attention with ``hooked_attention_forward``
        # (applied to every block), whereas _activate_timm_self_attention_hooks
        # uses ``hooked_attention_timm_forward`` for the same kind of module --
        # confirm which hook is intended here.
        for block in trunk.blocks:
            if hasattr(block, "attn"):
                block.attn.forward = types.MethodType(
                    hooked_attention_forward, block.attn
                )
        # --- Deactivate gradient for module that don't need it ---
        self._freeze_all_but_deep_blocks(
            "visual.trunk.blocks",
            extra_trainable_prefixes=("visual.trunk.attn_pool.",),
        )
        # --- Activate the hooks for the specific layers by modifying the block's forward ---
        for layer in range(self.starting_depth, len(trunk.blocks)):
            trunk.blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, trunk.blocks[layer]
            )
        trunk.attn_pool.forward = types.MethodType(
            hooked_attentional_pooler_timm_forward, trunk.attn_pool
        )

    def compute_legrad(self, text_embedding, image=None, apply_correction=True):
        """Dispatch to the LeGrad variant matching ``self.model_type``.

        ``apply_correction`` is only used by the SigLIP variant.

        Raises:
            ValueError: if ``self.model_type`` is not recognised.
        """
        # Note: "clip" matches both "clip" and "timm_clip"; it does not match "timm_siglip".
        if "clip" in self.model_type:
            return self.compute_legrad_clip(text_embedding, image)
        if "siglip" in self.model_type:
            return self.compute_legrad_siglip(
                text_embedding, image, apply_correction=apply_correction
            )
        if "coca" in self.model_type:
            return self.compute_legrad_coca(text_embedding, image)
        raise ValueError(f"Unsupported model_type: {self.model_type!r}")

    def _project_clip_block_features(self, feat):
        """Mean-pool one block's tokens and project them into the image-embedding space.

        Returns an L2-normalized tensor with one row per image.
        """
        if isinstance(self.visual, VisionTransformer):
            # open_clip layout [num_tokens, batch, dim]; reuse the tower's ln_post/proj.
            pooled = self.visual.ln_post(feat.mean(dim=0)) @ self.visual.proj
        else:
            # timm layout [batch, num_tokens, dim]
            pooled = self.visual.head(self.visual.trunk.norm(feat.mean(dim=1)))
        return F.normalize(pooled, dim=-1)

    def compute_legrad_clip(self, text_embedding, image=None):
        """LeGrad heatmap for self-attention towers (``"clip"`` / ``"timm_clip"``).

        Args:
            text_embedding: text embeddings, one row per prompt.
            image: optional preprocessed image batch; when ``None`` the features
                recorded by the previous ``encode_image`` call are reused.

        Returns:
            Heatmap ``[1, b, w * patch_size, h * patch_size]`` (``b`` = image
            batch), normalized with ``min_max``.
        """
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # Forward pass so the hooked blocks record features / attention maps.
            _ = self.encode_image(image)

        if isinstance(self.visual, VisionTransformer):
            blocks = self.visual.transformer.resblocks
            token_dim = 0  # open_clip features: [num_tokens, batch, dim]
        else:
            blocks = self.visual.trunk.blocks
            token_dim = 1  # timm features: [batch, num_tokens, dim]
        blocks_list = list(dict(blocks.named_children()).values())
        deep_blocks = blocks_list[self.starting_depth :]

        image_features_list = [
            self._project_clip_block_features(blk.feat_post_mlp) for blk in deep_blocks
        ]
        # The first token is dropped (class token); the rest form a square grid.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[token_dim] - 1
        w = h = int(math.sqrt(num_tokens))

        # Diagonal selector pairing prompt i with image i.
        # NOTE(review): with a single image and several prompts, broadcasting turns
        # the objective into the sum over all prompts, i.e. one combined map.
        one_hot = torch.eye(num_prompts, dtype=torch.float32, device=text_embedding.device)

        # ----- Get explainability map
        accum_expl_map = 0
        for blk, img_feat in zip(deep_blocks, image_features_list):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)  # [num_prompts, batch]
            objective = torch.sum(one_hot * sim)
            attn_map = blk.attn.attention_map
            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                objective, [attn_map], retain_graph=True, create_graph=True
            )[0]
            if grad.dim() == 3:
                # Hooks that store [batch * num_heads, N, N]: separate batch and heads.
                grad = rearrange(grad, "(b h) n m -> b h n m", b=img_feat.shape[0])
            grad = torch.clamp(grad, min=0.0)
            # Average over heads, then over query tokens; drop the class-token column.
            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map
        # Min-Max Norm
        return min_max(accum_expl_map)

    def compute_legrad_coca(self, text_embedding, image=None):
        """LeGrad heatmap for open_clip towers with an attentional pooler (CoCa).

        Args:
            text_embedding: text embeddings, one row per prompt.
            image: optional preprocessed image batch; when ``None`` the features
                recorded by the previous ``encode_image`` call are reused.

        Returns:
            Heatmap ``[1, 1, w * patch_size, h * patch_size]`` rescaled to [0, 1]
            (NaN if the accumulated map is constant).
        """
        if image is not None:
            _ = self.encode_image(image)
        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )
        # [num_patch, batch, dim] -> [batch, num_patch, dim]
        image_features_list = [
            blk.feat_post_mlp.permute(1, 0, 2)
            for blk in blocks_list[self.starting_depth :]
        ]
        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))
        # ----- Get explainability map
        accum_expl_map = 0
        for img_feat in image_features_list:
            self.visual.zero_grad()
            # --- Apply attn_pool ---
            # Keep only the first pooled token: it is the one trained with the contrastive loss.
            image_embedding = self.visual.attn_pool(img_feat)[:, 0]
            image_embedding = image_embedding @ self.visual.proj
            sim = text_embedding @ image_embedding.transpose(-1, -2)
            objective = torch.sum(sim)
            attn_map = self.visual.attn_pool.attn.attention_maps  # [num_heads, num_latent, num_patch]
            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                objective, [attn_map], retain_graph=True, create_graph=True
            )[0]
            grad = torch.clamp(grad, min=0.0)
            # Average over heads, select the first latent, drop the class-token column.
            image_relevance = grad.mean(dim=0)[0, 1:]
            expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map
        # Min-Max Norm
        accum_expl_map = (accum_expl_map - accum_expl_map.min()) / (
            accum_expl_map.max() - accum_expl_map.min()
        )
        return accum_expl_map

    def _init_empty_embedding(self):
        """Lazily encode the generic prompt ``"a photo of a"`` and cache it.

        The cached ``self.empty_embedding`` is L2-normalized and transposed so it
        can right-multiply pooled image features.
        """
        if not hasattr(self, "empty_embedding"):
            # For the moment only SigLIP is supported & they all have the same tokenizer
            _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP")
            empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device)
            empty_embedding = self.encode_text(empty_text)
            empty_embedding = F.normalize(empty_embedding, dim=-1)
            self.empty_embedding = empty_embedding.t()

    def compute_legrad_siglip(
        self,
        text_embedding,
        image=None,
        apply_correction=True,
        correction_threshold=0.8,
    ):
        """LeGrad heatmap for timm towers with an attentional pooler (SigLIP).

        Args:
            text_embedding: text embeddings, one row per prompt.
            image: optional preprocessed image batch; when ``None`` the features
                recorded by the previous ``encode_image`` call are reused.
            apply_correction: if True, zero out locations whose heatmap for the
                generic ``"a photo of a"`` prompt exceeds ``correction_threshold``
                after min-max normalization.
            correction_threshold: cut-off used by ``apply_correction``.

        Returns:
            Heatmap ``[b, 1, w * patch_size, h * patch_size]`` normalized with
            ``min_max`` before upsampling.
        """
        # --- Forward CLIP ---
        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
        if image is not None:
            _ = self.encode_image(image)  # [bs, num_patch, dim] bs=num_masks
        image_features_list = [
            blk.feat_post_mlp for blk in blocks_list[self.starting_depth :]
        ]
        # Every token is treated as a patch of the square grid (no class token dropped).
        num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
        w = h = int(math.sqrt(num_tokens))
        if apply_correction:
            self._init_empty_embedding()
            accum_expl_map_empty = 0
        attn_pool = self.visual.trunk.attn_pool
        accum_expl_map = 0
        for img_feat in image_features_list:
            self.zero_grad()
            pooled_feat = F.normalize(attn_pool(img_feat), dim=-1)
            # -------- Get explainability map --------
            sim = text_embedding @ pooled_feat.transpose(-1, -2)  # [num_mask, num_mask]
            grad = torch.autograd.grad(
                torch.sum(sim),
                [attn_pool.attn_probs],
                retain_graph=True,
                create_graph=True,
            )[0]
            grad = torch.clamp(grad, min=0.0)
            # Average over heads (dim 1), then select the first query's row.
            image_relevance = grad.mean(dim=1)[:, 0]
            expl_map = rearrange(image_relevance, "b (w h) -> b 1 w h", w=w, h=h)
            accum_expl_map += expl_map
            if apply_correction:
                # -------- Get empty explainability map --------
                sim_empty = pooled_feat @ self.empty_embedding
                grad_empty = torch.autograd.grad(
                    torch.sum(sim_empty),
                    [attn_pool.attn_probs],
                    retain_graph=True,
                    create_graph=True,
                )[0]
                grad_empty = torch.clamp(grad_empty, min=0.0)
                # Average over heads, then select the first query's row.
                image_relevance_empty = grad_empty.mean(dim=1)[:, 0]
                expl_map_empty = rearrange(
                    image_relevance_empty, "b (w h) -> b 1 w h", w=w, h=h
                )
                accum_expl_map_empty += expl_map_empty
        if apply_correction:
            heatmap_empty = min_max(accum_expl_map_empty)
            accum_expl_map[heatmap_empty > correction_threshold] = 0
        heatmap = min_max(accum_expl_map)
        heatmap = F.interpolate(
            heatmap, scale_factor=self.patch_size, mode="bilinear"
        )  # [B, 1, H, W]
        return heatmap
class LePreprocess(nn.Module):
    """
    Rebuild an OpenCLIP preprocessing pipeline around a caller-chosen square size.

    The output pipeline is a bicubic ``Resize`` to ``(image_size, image_size)``
    followed by the last three transforms of ``preprocess``, reused as-is.
    Any transforms that come before those three are dropped; in OpenCLIP's
    default pipeline these are presumably the resize / center-crop steps.
    """

    def __init__(self, preprocess, image_size):
        super().__init__()
        square_resize = Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        # Explicit indices, so a pipeline shorter than three steps still raises IndexError.
        reused_tail = [preprocess.transforms[idx] for idx in (-3, -2, -1)]
        self.transform = Compose([square_resize, *reused_tail])

    def forward(self, image):
        """Apply the square resize followed by the reused tail transforms."""
        pipeline = self.transform
        return pipeline(image)