import math
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, InterpolationMode

import open_clip
from open_clip.transformer import VisionTransformer
from open_clip.timm_model import TimmModel
from einops import rearrange

from .utils import (
    hooked_attention_timm_forward,
    hooked_resblock_forward,
    hooked_attention_forward,
    hooked_resblock_timm_forward,
    hooked_attentional_pooler_timm_forward,
    vit_dynamic_size_forward,
    min_max,
    hooked_torch_multi_head_attention_forward,
)


class LeWrapper(nn.Module):
    """
    Wrapper around an OpenCLIP model that adds LeGrad while keeping all the
    functionality of the original model.
    """

    def __init__(self, model, layer_index=-2):
        super(LeWrapper, self).__init__()

        # Copy every public attribute of the original model so the wrapper
        # exposes the same interface (encode_image, encode_text, etc.).
        for attr in dir(model):
            if not attr.startswith("__"):
                setattr(self, attr, getattr(model, attr))

        self._activate_hooks(layer_index=layer_index)

    def _activate_hooks(self, layer_index):
        print("Activating necessary hooks and gradients ...")
        if isinstance(self.visual, VisionTransformer):
            # Allow arbitrary input resolutions.
            self.visual.forward = types.MethodType(
                vit_dynamic_size_forward, self.visual
            )

            self.patch_size = self.visual.patch_size[0]

            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.transformer.resblocks) + layer_index
            )

            if self.visual.attn_pool is None:
                self.model_type = "clip"
                self._activate_self_attention_hooks()
            else:
                self.model_type = "coca"
                self._activate_att_pool_hooks(layer_index=layer_index)

        elif isinstance(self.visual, TimmModel):
            # Allow arbitrary input resolutions.
            self.visual.trunk.dynamic_img_size = True
            self.visual.trunk.patch_embed.dynamic_img_size = True
            self.visual.trunk.patch_embed.strict_img_size = False
            self.visual.trunk.patch_embed.flatten = False
            self.visual.trunk.patch_embed.output_fmt = "NHWC"
            self.model_type = "timm_siglip"

            self.patch_size = self.visual.trunk.patch_embed.patch_size[0]

            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.trunk.blocks) + layer_index
            )
            if (
                hasattr(self.visual.trunk, "attn_pool")
                and self.visual.trunk.attn_pool is not None
            ):
                self._activate_timm_attn_pool_hooks(layer_index=layer_index)
            else:
                self._activate_timm_self_attention_hooks()
        else:
            raise ValueError(
                "Model currently not supported, see legrad.list_pretrained() for a list of available models"
            )
        print("Hooks and gradients activated!")

    def _activate_self_attention_hooks(self):
        if isinstance(self.visual, VisionTransformer):
            blocks = self.visual.transformer.resblocks
            blocks_prefix = "visual.transformer.resblocks"
        elif isinstance(self.visual, TimmModel):
            blocks = self.visual.trunk.blocks
            blocks_prefix = "visual.trunk.blocks"
        else:
            raise ValueError("Unsupported model type for self-attention hooks")

        # Enable gradients only for the blocks from `starting_depth` onwards.
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith(blocks_prefix):
                depth = int(name.split(blocks_prefix + ".")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # Replace the attention and block forwards with hooked versions that
        # store the attention maps and post-MLP features.
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_forward, blocks[layer]
            )

    def _activate_timm_self_attention_hooks(self):
        blocks = self.visual.trunk.blocks

        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.blocks"):
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_timm_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, blocks[layer]
            )

    def _activate_att_pool_hooks(self, layer_index):
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.transformer.resblocks"):
                depth = int(
                    name.split("visual.transformer.resblocks.")[-1].split(".")[0]
                )
                if depth >= self.starting_depth:
                    param.requires_grad = True

        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            self.visual.transformer.resblocks[layer].forward = types.MethodType(
                hooked_resblock_forward, self.visual.transformer.resblocks[layer]
            )

        self.visual.attn_pool.attn.forward = types.MethodType(
            hooked_torch_multi_head_attention_forward, self.visual.attn_pool.attn
        )

    def _activate_timm_attn_pool_hooks(self, layer_index):
        if (
            not hasattr(self.visual.trunk, "attn_pool")
            or self.visual.trunk.attn_pool is None
        ):
            raise ValueError("Attentional pooling not found in TimmModel")

        self.visual.trunk.attn_pool.forward = types.MethodType(
            hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool
        )
        for block in self.visual.trunk.blocks:
            if hasattr(block, "attn"):
                block.attn.forward = types.MethodType(
                    hooked_attention_timm_forward, block.attn
                )

        # Enable gradients for the attentional pooler and the last blocks.
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.attn_pool"):
                param.requires_grad = True
            if name.startswith("visual.trunk.blocks"):
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
            self.visual.trunk.blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, self.visual.trunk.blocks[layer]
            )

    def compute_legrad(self, text_embedding, image=None, apply_correction=True):
        """Dispatch to the LeGrad variant matching the wrapped model type."""
        if "clip" in self.model_type:
            return self.compute_legrad_clip(text_embedding, image)
        elif "siglip" in self.model_type:
            return self.compute_legrad_siglip(
                text_embedding, image, apply_correction=apply_correction
            )
        elif "coca" in self.model_type:
            return self.compute_legrad_coca(text_embedding, image)

    def compute_legrad_clip(self, text_embedding, image=None):
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # Forward pass to populate the hooked intermediate features.
            _ = self.encode_image(image)

        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )

        # Layer-wise image embeddings: mean-pooled tokens, projected into the
        # shared text-image space and normalized.
        image_features_list = []
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            intermediate_feat = self.visual.transformer.resblocks[layer].feat_post_mlp
            intermediate_feat = intermediate_feat.permute(1, 0, 2)  # [batch, tokens, dim]
            intermediate_feat = intermediate_feat.mean(dim=1)
            intermediate_feat = self.visual.ln_post(intermediate_feat) @ self.visual.proj
            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1  # exclude the CLS token
        w = h = int(math.sqrt(num_tokens))

        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)
            one_hot = (
                F.one_hot(torch.arange(0, num_prompts))
                .float()
                .requires_grad_(True)
                .to(text_embedding.device)
            )
            one_hot = torch.sum(one_hot * sim)

            attn_map = blk.attn.attention_map  # stored by the hooked attention forward

            # Gradient of the similarity w.r.t. the attention map, keeping only
            # positive contributions.
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            grad = torch.clamp(grad, min=0.0)

            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)

            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        accum_expl_map = min_max(accum_expl_map)
        return accum_expl_map

    def compute_legrad_coca(self, text_embedding, image=None):
        if image is not None:
            _ = self.encode_image(image)

        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )

        image_features_list = []
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            intermediate_feat = self.visual.transformer.resblocks[layer].feat_post_mlp
            intermediate_feat = intermediate_feat.permute(1, 0, 2)
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))

        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()

            image_embedding = self.visual.attn_pool(img_feat)[:, 0]
            image_embedding = image_embedding @ self.visual.proj

            sim = text_embedding @ image_embedding.transpose(-1, -2)
            one_hot = torch.sum(sim)

            attn_map = self.visual.attn_pool.attn.attention_maps

            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            grad = torch.clamp(grad, min=0.0)

            image_relevance = grad.mean(dim=0)[0, 1:]
            expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        accum_expl_map = (accum_expl_map - accum_expl_map.min()) / (
            accum_expl_map.max() - accum_expl_map.min()
        )
        return accum_expl_map

    def _init_empty_embedding(self):
        # Cache the text embedding of an "empty" prompt, used by the SigLIP
        # correction to suppress patches that respond to generic text.
        if not hasattr(self, "empty_embedding"):
            _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP")
            empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device)
            empty_embedding = self.encode_text(empty_text)
            empty_embedding = F.normalize(empty_embedding, dim=-1)
            self.empty_embedding = empty_embedding.t()

    def compute_legrad_siglip(
        self,
        text_embedding,
        image=None,
        apply_correction=True,
        correction_threshold=0.8,
    ):
        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
        if image is not None:
            _ = self.encode_image(image)

        image_features_list = []
        for blk in blocks_list[self.starting_depth :]:
            intermediate_feat = blk.feat_post_mlp
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
        w = h = int(math.sqrt(num_tokens))

        if apply_correction:
            self._init_empty_embedding()
            accum_expl_map_empty = 0

        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.zero_grad()
            pooled_feat = self.visual.trunk.attn_pool(img_feat)
            pooled_feat = F.normalize(pooled_feat, dim=-1)

            sim = text_embedding @ pooled_feat.transpose(-1, -2)
            one_hot = torch.sum(sim)
            grad = torch.autograd.grad(
                one_hot,
                [self.visual.trunk.attn_pool.attn_probs],
                retain_graph=True,
                create_graph=True,
            )[0]
            grad = torch.clamp(grad, min=0.0)

            image_relevance = grad.mean(dim=1)[:, 0]
            expl_map = rearrange(image_relevance, "b (w h) -> b 1 w h", w=w, h=h)
            accum_expl_map += expl_map

            if apply_correction:
                sim_empty = pooled_feat @ self.empty_embedding
                one_hot_empty = torch.sum(sim_empty)
                grad_empty = torch.autograd.grad(
                    one_hot_empty,
                    [self.visual.trunk.attn_pool.attn_probs],
                    retain_graph=True,
                    create_graph=True,
                )[0]
                grad_empty = torch.clamp(grad_empty, min=0.0)

                image_relevance_empty = grad_empty.mean(dim=1)[:, 0]
                expl_map_empty = rearrange(
                    image_relevance_empty, "b (w h) -> b 1 w h", w=w, h=h
                )
                accum_expl_map_empty += expl_map_empty

        if apply_correction:
            heatmap_empty = min_max(accum_expl_map_empty)
            accum_expl_map[heatmap_empty > correction_threshold] = 0

        res = min_max(accum_expl_map)
        res = F.interpolate(res, scale_factor=self.patch_size, mode="bilinear")

        return res


class LePreprocess(nn.Module):
    """
    Modify OpenCLIP preprocessing to accept an arbitrary image size.
    """

    def __init__(self, preprocess, image_size):
        super(LePreprocess, self).__init__()
        self.transform = Compose(
            [
                Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                preprocess.transforms[-3],
                preprocess.transforms[-2],
                preprocess.transforms[-1],
            ]
        )

    def forward(self, image):
        return self.transform(image)
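

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: run it via `python -m <package>.<module>`
    # so the relative import above resolves. The model name, pretrained tag, image
    # path, and prompt below are assumptions, not fixed choices of this module.
    from PIL import Image

    model_name, pretrained = "ViT-B-16", "laion2b_s34b_b88k"
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_name)

    model = LeWrapper(model)  # attach the LeGrad hooks defined above
    preprocess = LePreprocess(preprocess=preprocess, image_size=448)  # optional higher-res input

    image = preprocess(Image.open("assets/example.jpg")).unsqueeze(0)
    text = tokenizer(["a photo of a cat"])
    with torch.no_grad():
        text_embedding = F.normalize(model.encode_text(text), dim=-1)

    # Explainability heatmap, min-max normalized and upsampled to the input resolution.
    heatmap = model.compute_legrad(text_embedding, image=image)
    print(heatmap.shape)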