import torch
import torch.nn.functional as F
import einops
from torchvision import transforms
from tqdm import tqdm
import PIL.Image

from concept_attention.binary_segmentation_baselines.clip_text_span.prs_hook import hook_prs_logger
from concept_attention.binary_segmentation_baselines.clip_text_span.utils.factory import create_model_and_transforms, get_tokenizer
from concept_attention.binary_segmentation_baselines.clip_text_span.utils.openai_templates import OPENAI_IMAGENET_TEMPLATES
from concept_attention.segmentation import SegmentationAbstractClass


class CLIPTextSpanSegmentationModel(SegmentationAbstractClass):

    def __init__(
        self,
        model_name='ViT-H-14',
        pretrained='laion2b_s32b_b79k',
        device='cuda:3'
    ):
        self.device = device
        # Load the CLIP model and the tokenizer
        self.clip_model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        self.clip_model.to(device)
        self.clip_model.eval()
        context_length = self.clip_model.context_length
        vocab_size = self.clip_model.vocab_size
        self.tokenizer = get_tokenizer(model_name)
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Hook the model so per-layer, per-head attention contributions are recorded
        self.prs = hook_prs_logger(self.clip_model, device)

    def generate_clip_vectors_for_concepts(self, concepts: list[str]):
        """
        Produce one CLIP text embedding per concept by averaging the
        embeddings of the OpenAI ImageNet prompt templates.
        """
        autocast = torch.cuda.amp.autocast
        with torch.no_grad(), autocast():
            zeroshot_weights = []
            for classname in tqdm(concepts):
                texts = [template(classname) for template in OPENAI_IMAGENET_TEMPLATES]
                texts = self.tokenizer(texts).to(self.device)  # tokenize
                class_embeddings = self.clip_model.encode_text(texts)
                # Average the normalized template embeddings, then renormalize
                class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
                class_embedding /= class_embedding.norm()
                zeroshot_weights.append(class_embedding)
            zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)  # [d, concepts]
        return zeroshot_weights

    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        # Apply the preprocessing transform (convert tensors back to PIL first)
        if isinstance(image, PIL.Image.Image):
            image = self.image_transform(image)
        else:
            image = transforms.ToPILImage()(image)
            image = self.image_transform(image)
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        image_size = image.shape[-1]
        # Compute CLIP text vectors for each concept
        concept_vectors = self.generate_clip_vectors_for_concepts(concepts)
        concept_vectors = concept_vectors.detach().cpu()
        # Encode the image while the hooks record per-head attention contributions
        self.prs.reinit()
        representation = self.clip_model.encode_image(
            image.to(self.device), attn_method="head", normalize=False
        )
        attentions, _ = self.prs.finalize(representation)
        representation = representation.detach().cpu()
        attentions = attentions.detach().cpu()  # [b, l, n, h, d]
        # chosen_class = (representation @ concept_vectors).argmax(axis=1)
        # Drop the first (class) token and sum the contributions over layers and heads
        attentions_collapse = attentions[:, :, 1:].sum(axis=(1, 3))  # [b, n - 1, d]
        concept_heatmaps = attentions_collapse @ concept_vectors  # [b, n - 1, concepts]
        # Reshape the flat patch axis back into a 2D heatmap per concept
        patches = image_size // self.clip_model.visual.patch_size[0]
        concept_heatmaps = einops.rearrange(
            concept_heatmaps,
            "1 (h w) concepts -> concepts h w",
            h=patches, w=patches
        )
        # NOTE: None stands in for the reconstructed image, which this model does not produce
        return concept_heatmaps, None
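

# Hypothetical usage sketch (illustrative only; the file path, device index,
# concepts, and caption below are assumptions, not part of the original module).
if __name__ == "__main__":
    from PIL import Image

    # A minimal sketch of how the model might be invoked, assuming a CUDA
    # device and an image file named "example.jpg" are available locally.
    model = CLIPTextSpanSegmentationModel(device="cuda:0")
    image = Image.open("example.jpg").convert("RGB")
    concepts = ["dog", "grass", "sky"]
    heatmaps, _ = model.segment_individual_image(image, concepts, caption="a dog on the grass")
    # heatmaps has shape [num_concepts, patches, patches]; taking the argmax
    # over the concept axis yields a coarse per-patch segmentation map.
    segmentation = heatmaps.argmax(dim=0)
    print(segmentation.shape)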