# Based on a file from https://github.com/rinongal/StyleGAN-nada.
# ==========================================================================================
#
# Adobe’s modifications are Copyright 2023 Adobe Research. All rights reserved.
# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of
# the license, visit LICENSE.md.
#
# ==========================================================================================
import clip
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import transforms

from expansion_utils.text_templates import imagenet_templates, part_templates
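
# This module implements the CLIP-based losses used for StyleGAN-NADA-style domain
# adaptation: a global CLIP loss, the directional loss from the paper, patch-level
# variants, a manifold (angle) loss, and a CNN-feature texture loss.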

# TODO: get rid of unused stuff in this class
class CLIPLoss(torch.nn.Module):
    def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0., lambda_manifold=0.,
                 lambda_texture=0., patch_loss_type='mae', direction_loss_type='cosine', clip_model='ViT-B/32'):
        super(CLIPLoss, self).__init__()

        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)
        self.clip_preprocess = clip_preprocess

        # Build a tensor-only preprocessing pipeline: Normalize(mean=-1, std=2) maps
        # x -> (x + 1) / 2, i.e. un-normalizes GAN output from [-1, 1] to [0, 1].
        # We then reuse CLIP's own Resize and CenterCrop (transforms[:2]) and its
        # final Normalize (transforms[4:]), skipping the PIL conversion and ToTensor
        # steps, which do not apply to tensor inputs.
        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +
            clip_preprocess.transforms[:2] +
            clip_preprocess.transforms[4:])

        self.target_directions_cache = {}
        self.patch_text_directions = None

        self.patch_loss = DirectionLoss(patch_loss_type)
        self.direction_loss = DirectionLoss(direction_loss_type)
        self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)

        self.lambda_global = lambda_global
        self.lambda_patch = lambda_patch
        self.lambda_direction = lambda_direction
        self.lambda_manifold = lambda_manifold
        self.lambda_texture = lambda_texture

        self.src_text_features = None
        self.target_text_features = None
        self.angle_loss = torch.nn.L1Loss()

        # A second, convolutional CLIP backbone, used only for the texture loss.
        self.model_cnn, preprocess_cnn = clip.load("RN50", device=self.device)
        self.preprocess_cnn = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # Un-normalize from [-1, 1] to [0, 1].
            preprocess_cnn.transforms[:2] +  # CLIP's Resize and CenterCrop.
            preprocess_cnn.transforms[4:])   # CLIP's Normalize; skip PIL conversion and ToTensor.

        self.model.requires_grad_(False)
        self.model_cnn.requires_grad_(False)

        self.texture_loss = torch.nn.MSELoss()

    def tokenize(self, strings: list):
        return clip.tokenize(strings).to(self.device)

    def encode_text(self, tokens: list) -> torch.Tensor:
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def encode_images_with_cnn(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess_cnn(images).to(self.device)
        return self.model_cnn.encode_image(images)

    def distance_with_templates(self, img: torch.Tensor, class_str: str, templates=imagenet_templates) -> torch.Tensor:
        text_features = self.get_text_features(class_str, templates)
        image_features = self.get_image_features(img)

        similarity = image_features @ text_features.T

        return 1. - similarity

    def get_text_features(self, class_str: str, templates=imagenet_templates, norm: bool = True) -> torch.Tensor:
        template_text = self.compose_text_with_templates(class_str, templates)
        tokens = clip.tokenize(template_text).to(self.device)
        text_features = self.encode_text(tokens).detach()

        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor:
        with torch.no_grad():
            source_features = self.get_text_features(source_class)
            target_features = self.get_text_features(target_class)

            # Average the per-template differences, then renormalize to a unit direction.
            text_direction = (target_features - source_features).mean(dim=0, keepdim=True)
            text_direction /= text_direction.norm(dim=-1, keepdim=True)

        return text_direction

    def compute_img2img_direction(self, source_images: torch.Tensor, target_images: list) -> torch.Tensor:
        with torch.no_grad():
            src_encoding = self.get_image_features(source_images)
            src_encoding = src_encoding.mean(dim=0, keepdim=True)

            target_encodings = []
            for target_img in target_images:
                preprocessed = self.clip_preprocess(Image.open(target_img)).unsqueeze(0).to(self.device)

                encoding = self.model.encode_image(preprocessed)
                encoding /= encoding.norm(dim=-1, keepdim=True)

                target_encodings.append(encoding)

            target_encoding = torch.cat(target_encodings, dim=0)
            target_encoding = target_encoding.mean(dim=0, keepdim=True)

            direction = target_encoding - src_encoding
            direction /= direction.norm(dim=-1, keepdim=True)

        return direction

    def set_text_features(self, source_class: str, target_class: str) -> None:
        source_features = self.get_text_features(source_class).mean(dim=0, keepdim=True)
        self.src_text_features = source_features / source_features.norm(dim=-1, keepdim=True)

        target_features = self.get_text_features(target_class).mean(dim=0, keepdim=True)
        self.target_text_features = target_features / target_features.norm(dim=-1, keepdim=True)

    def clip_angle_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor,
                        target_class: str) -> torch.Tensor:
        if self.src_text_features is None:
            self.set_text_features(source_class, target_class)

        cos_text_angle = self.target_text_features @ self.src_text_features.T
        text_angle = torch.acos(cos_text_angle)

        src_img_features = self.get_image_features(src_img).unsqueeze(2)
        target_img_features = self.get_image_features(target_img).unsqueeze(1)

        cos_img_angle = torch.clamp(target_img_features @ src_img_features, min=-1.0, max=1.0)
        img_angle = torch.acos(cos_img_angle)

        text_angle = text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
        cos_text_angle = cos_text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)

        return self.angle_loss(cos_img_angle, cos_text_angle)

    def compose_text_with_templates(self, text: str, templates=imagenet_templates) -> list:
        return [template.format(text) for template in templates]
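
    # Directional loss (StyleGAN-NADA): with the text direction
    #     delta_T = E_T(t_target) - E_T(t_source)   (template-averaged, unit-normalized)
    # and the image direction
    #     delta_I = E_I(x_target) - E_I(x_source),
    # the loss is 1 - cos(delta_I, delta_T), pushing image embeddings to move
    # parallel to the text direction. Text directions are cached per
    # (source_class, target_class) pair.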

    def clip_directional_loss(self, src_img: torch.Tensor, source_classes: np.ndarray, target_img: torch.Tensor,
                              target_classes: np.ndarray) -> torch.Tensor:
        target_directions = []
        for key in zip(source_classes, target_classes):
            if key not in self.target_directions_cache:
                self.target_directions_cache[key] = self.compute_text_direction(*key)
            target_directions.append(self.target_directions_cache[key])
        target_directions = torch.cat(target_directions)

        src_encoding = self.get_image_features(src_img)
        target_encoding = self.get_image_features(target_img)

        edit_direction = (target_encoding - src_encoding)
        if edit_direction.sum() == 0:
            # Identical encodings would make the normalization below divide by zero;
            # nudge the target image to obtain a nonzero direction.
            target_encoding = self.get_image_features(target_img + 1e-6)
            edit_direction = (target_encoding - src_encoding)

        edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)

        return self.direction_loss(edit_direction, target_directions).sum()

    def global_clip_loss(self, img: torch.Tensor, text) -> torch.Tensor:
        if not isinstance(text, list):
            text = [text]

        tokens = clip.tokenize(text).to(self.device)
        image = self.preprocess(img)

        logits_per_image, _ = self.model(image, tokens)

        # CLIP scales image-text cosine similarities by a learned factor (~100),
        # so dividing by 100 maps the logits back to roughly [0, 1] before
        # converting similarity to a loss.
        return (1. - logits_per_image / 100).mean()

    def random_patch_centers(self, img_shape, num_patches, size):
        batch_size, channels, height, width = img_shape

        half_size = size // 2
        patch_centers = np.concatenate(
            [np.random.randint(half_size, width - half_size, size=(batch_size * num_patches, 1)),
             np.random.randint(half_size, height - half_size, size=(batch_size * num_patches, 1))], axis=1)

        return patch_centers

    def generate_patches(self, img: torch.Tensor, patch_centers, size):
        batch_size = img.shape[0]
        num_patches = len(patch_centers) // batch_size
        half_size = size // 2

        patches = []
        for batch_idx in range(batch_size):
            for patch_idx in range(num_patches):
                center_x = patch_centers[batch_idx * num_patches + patch_idx][0]
                center_y = patch_centers[batch_idx * num_patches + patch_idx][1]

                patch = img[batch_idx:batch_idx + 1, :, center_y - half_size:center_y + half_size,
                            center_x - half_size:center_x + half_size]
                patches.append(patch)

        patches = torch.cat(patches, dim=0)

        return patches

    def patch_scores(self, img: torch.Tensor, class_str: str, patch_centers, patch_size: int) -> torch.Tensor:
        parts = self.compose_text_with_templates(class_str, part_templates)
        tokens = clip.tokenize(parts).to(self.device)
        text_features = self.encode_text(tokens).detach()

        patches = self.generate_patches(img, patch_centers, patch_size)
        image_features = self.get_image_features(patches)

        similarity = image_features @ text_features.T

        return similarity

    def clip_patch_similarity(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor,
                              target_class: str) -> torch.Tensor:
        patch_size = 196  # TODO remove magic number

        patch_centers = self.random_patch_centers(src_img.shape, 4, patch_size)  # TODO remove magic number

        src_scores = self.patch_scores(src_img, source_class, patch_centers, patch_size)
        target_scores = self.patch_scores(target_img, target_class, patch_centers, patch_size)

        return self.patch_loss(src_scores, target_scores)

    def patch_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor,
                               target_class: str) -> torch.Tensor:
        if self.patch_text_directions is None:
            src_part_classes = self.compose_text_with_templates(source_class, part_templates)
            target_part_classes = self.compose_text_with_templates(target_class, part_templates)

            parts_classes = list(zip(src_part_classes, target_part_classes))
            self.patch_text_directions = torch.cat(
                [self.compute_text_direction(pair[0], pair[1]) for pair in parts_classes], dim=0)

        patch_size = 510  # TODO remove magic numbers

        patch_centers = self.random_patch_centers(src_img.shape, 1, patch_size)

        patches = self.generate_patches(src_img, patch_centers, patch_size)
        src_features = self.get_image_features(patches)

        patches = self.generate_patches(target_img, patch_centers, patch_size)
        target_features = self.get_image_features(patches)

        edit_direction = (target_features - src_features)
        edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)

        cosine_dists = 1. - self.patch_direction_loss(edit_direction.unsqueeze(1),
                                                      self.patch_text_directions.unsqueeze(0))

        patch_class_scores = cosine_dists * (edit_direction @ self.patch_text_directions.T).softmax(dim=-1)

        return patch_class_scores.mean()

    def cnn_feature_loss(self, src_img: torch.Tensor, target_img: torch.Tensor) -> torch.Tensor:
        src_features = self.encode_images_with_cnn(src_img)
        target_features = self.encode_images_with_cnn(target_img)

        return self.texture_loss(src_features, target_features)
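
    # The total loss is a weighted sum of the enabled terms:
    #     L = lambda_global * L_global + lambda_patch * L_patch
    #       + lambda_direction * L_direction + lambda_manifold * L_angle
    #       + lambda_texture * L_texture
    # Terms whose lambda is zero are skipped entirely.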

    def forward(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str,
                texture_image: torch.Tensor = None):
        clip_loss = 0.0

        if self.lambda_global:
            clip_loss += self.lambda_global * self.global_clip_loss(target_img, [f"a {target_class}"])

        if self.lambda_patch:  # The same directional loss, but applied to patches.
            clip_loss += self.lambda_patch * self.patch_directional_loss(src_img, source_class, target_img,
                                                                         target_class)

        if self.lambda_direction:  # The directional loss used in the paper.
            clip_loss += self.lambda_direction * self.clip_directional_loss(src_img, source_class, target_img,
                                                                            target_class)

        if self.lambda_manifold:  # Compute angles of text and image directions and apply an L1 loss.
            clip_loss += self.lambda_manifold * self.clip_angle_loss(src_img, source_class, target_img, target_class)

        if self.lambda_texture and (texture_image is not None):  # L2 on features extracted by a CNN.
            clip_loss += self.lambda_texture * self.cnn_feature_loss(texture_image, target_img)

        return clip_loss


class DirectionLoss(torch.nn.Module):
    def __init__(self, loss_type='mse'):
        super(DirectionLoss, self).__init__()

        self.loss_type = loss_type
        self.loss_func = {
            'mse': torch.nn.MSELoss,
            'cosine': torch.nn.CosineSimilarity,
            'mae': torch.nn.L1Loss
        }[loss_type]()

    def forward(self, x, y):
        if self.loss_type == "cosine":
            return 1. - self.loss_func(x, y)

        return self.loss_func(x, y)
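

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): exercises the default
    # configuration, where only the directional loss is active. Assumes the ViT-B/32
    # and RN50 CLIP checkpoints can be downloaded, and stands in random 224x224
    # tensors in [-1, 1] for generator output; the class names here are hypothetical.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_loss = CLIPLoss(device)

    src = torch.rand(2, 3, 224, 224, device=device) * 2 - 1
    tgt = torch.rand(2, 3, 224, 224, device=device) * 2 - 1

    loss = clip_loss(src, ["photo", "photo"], tgt, ["sketch", "sketch"])
    print(loss.item())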