import torch
import wandb
import cv2
import torch.nn.functional as F
import numpy as np
from facenet_pytorch import MTCNN
from torchvision import transforms
from dreamsim import dreamsim
from einops import rearrange
import kornia.augmentation as K
import lpips

from pretrained_models.arcface import Backbone
from utils.vis_utils import add_text_to_image
from utils.utils import extract_faces_and_landmarks
import clip


class Loss:
    """
    General-purpose loss class.
    Mainly handles dtype and visualize_every_k.
    Keeps the current iteration count, mainly for visualization purposes.
    """
    def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs):
        self.visualize_every_k = visualize_every_k
        self.iteration = -1
        self.dtype = dtype
        self.accelerator = accelerator

    def __call__(self, **kwargs):
        self.iteration += 1
        return self.forward(**kwargs)
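

# --- Illustrative sketch (not part of the original training code) -----------
# Minimal example of the contract the loss classes below follow: subclasses
# implement `forward(**kwargs)` and are invoked through `Loss.__call__`, which
# advances `self.iteration` so `visualize_every_k` can gate visualization.
# `ZeroExampleLoss` and `_example_loss_usage` are hypothetical names used only
# for illustration; nothing in the training code calls them.
class ZeroExampleLoss(Loss):
    def forward(self, predict: torch.Tensor, **kwargs) -> torch.Tensor:
        # A trivial loss that is always zero but stays connected to the graph.
        return (predict * 0.0).mean()


def _example_loss_usage() -> torch.Tensor:
    loss_fn = ZeroExampleLoss(visualize_every_k=100)
    prediction = torch.randn(2, 3, 64, 64, requires_grad=True)
    return loss_fn(predict=prediction)  # increments loss_fn.iteration, then dispatches to forward()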


class L1Loss(Loss):
    """
    Simple L1 loss between `predict` and `target`.

    Args:
        predict (torch.Tensor): The predicted values, e.g. pixel values from a 1-step LCM prediction decoded by the VAE.
        target (torch.Tensor): The target values.
    """
    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        return F.l1_loss(predict, target, reduction="mean")


class DreamSIMLoss(Loss):
    """DreamSIM loss between predicted_pixel_values and encoder_pixel_values.

    DreamSIM (https://dreamsim-nights.github.io/) is similar to LPIPS but is trained on a dataset of human similarity judgments.
    DreamSIM expects RGB images of size 224x224 with values in [0, 1], so the inputs (in [-1, 1]) are rescaled to [0, 1] and resized to 224x224.

    Args:
        predicted_pixel_values (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
        encoder_pixel_values (torch.Tensor): The input image to the encoder.
    """
    def __init__(self, device: str = 'cuda:0', **kwargs):
        super().__init__(**kwargs)
        self.model, _ = dreamsim(pretrained=True, device=device)
        self.model.to(dtype=self.dtype, device=device)
        self.model = self.accelerator.prepare(self.model)
        self.transforms = transforms.Compose([
            transforms.Lambda(lambda x: (x + 1) / 2),
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)])

    def forward(
        self,
        predicted_pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # `Tensor.to` is not in-place, so the results must be reassigned.
        predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
        encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
        return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean()


class LPIPSLoss(Loss):
    """LPIPS loss between `predict` and `target`.

    Args:
        predict (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
        target (torch.Tensor): The input image to the encoder.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = lpips.LPIPS(net='vgg')
        self.model.to(dtype=self.dtype, device=self.accelerator.device)
        self.model = self.accelerator.prepare(self.model)

    def forward(self, predict, target, **kwargs):
        # `Tensor.to` is not in-place, so the results must be reassigned.
        predict = predict.to(dtype=self.dtype)
        target = target.to(dtype=self.dtype)
        return self.model(predict, target).mean()


class LCMVisualization(Loss):
    """Dummy loss used to visualize the LCM outputs.

    Args:
        predicted_pixel_values (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
        pixel_values (torch.Tensor): The input image to the decoder.
        encoder_pixel_values (torch.Tensor): The input image to the encoder.
    """
    def forward(
        self,
        predicted_pixel_values: torch.Tensor,
        pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        timesteps: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
            predicted_pixel_values = rearrange(predicted_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
            pixel_values = rearrange(pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
            encoder_pixel_values = rearrange(encoder_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
            image = np.hstack([encoder_pixel_values, pixel_values, predicted_pixel_values])
            for tracker in self.accelerator.trackers:
                if tracker.name == 'wandb':
                    tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")})
        return torch.tensor(0.0)


class L2Loss(Loss):
    """
    Regular diffusion loss between the predicted noise and the target noise.

    Args:
        predict (torch.Tensor): Noise predicted by the diffusion model.
        target (torch.Tensor): Actual noise added to the image.
        weights (torch.Tensor, optional): Per-element weighting applied to the squared error.
    """
    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        weights: torch.Tensor = None,
        **kwargs
    ) -> torch.Tensor:
        if weights is not None:
            loss = (predict.float() - target.float()).pow(2) * weights
            return loss.mean()
        return F.mse_loss(predict.float(), target.float(), reduction="mean")


class HuberLoss(Loss):
    """Huber loss between `predict` and `target`.

    Args:
        predict (torch.Tensor): The predicted pixel values using 1-step LCM and the VAE decoder.
        target (torch.Tensor): The input image to the encoder.
        weights (torch.Tensor, optional): Per-element weighting applied to the loss.
    """
    def __init__(self, huber_c=0.001, **kwargs):
        super().__init__(**kwargs)
        self.huber_c = huber_c

    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        weights: torch.Tensor = None,
        **kwargs
    ) -> torch.Tensor:
        loss = torch.sqrt((predict.float() - target.float()) ** 2 + self.huber_c**2) - self.huber_c
        if weights is not None:
            return (loss * weights).mean()
        return loss.mean()
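

# --- Illustrative sketch (not part of the original training code) -----------
# The expression above, sqrt(r**2 + c**2) - c with r = predict - target, is
# roughly quadratic (~ r**2 / (2c)) for |r| << huber_c and roughly |r| - c for
# |r| >> huber_c, i.e. it interpolates between L2-like and L1-like behaviour.
# `_huber_shape_check` is a hypothetical helper that only demonstrates this.
def _huber_shape_check(huber_c: float = 0.001) -> None:
    for r in (torch.tensor(1e-5), torch.tensor(1.0)):
        value = torch.sqrt(r ** 2 + huber_c ** 2) - huber_c
        print(f"r={r.item():.0e}  huber={value.item():.6f}  "
              f"quadratic~{(r ** 2 / (2 * huber_c)).item():.6f}  l1~{(r.abs() - huber_c).item():.6f}")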


class WeightedNoiseLoss(Loss):
    """
    Weighted diffusion loss between the predicted noise and the target noise.

    Args:
        predict (torch.Tensor): Noise predicted by the diffusion model.
        target (torch.Tensor): Actual noise added to the image.
        weights (torch.Tensor): Weighting for each batch item. Can be used, e.g., to zero out the loss for InstantID training when keypoint extraction fails.
    """
    def forward(
        self,
        predict: torch.Tensor,
        target: torch.Tensor,
        weights,
        **kwargs
    ) -> torch.Tensor:
        return F.mse_loss(predict.float() * weights, target.float() * weights, reduction="mean")
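

# --- Illustrative sketch (not part of the original training code) -----------
# Example of the per-item weighting described above: a weight of 0 removes a
# batch element from the loss (e.g. when keypoint extraction failed for it).
# The weights are shaped (B, 1, 1, 1) so they broadcast over the noise tensors;
# the shapes and values below are hypothetical.
def _weighted_noise_loss_example() -> torch.Tensor:
    loss_fn = WeightedNoiseLoss()
    predict = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)
    weights = torch.tensor([1.0, 1.0, 0.0, 1.0]).view(-1, 1, 1, 1)  # drop the third item
    return loss_fn(predict=predict, target=target, weights=weights)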


class IDLoss(Loss):
    """
    Identity loss that uses a pretrained ArcFace face-recognition backbone (referred to as `facenet` below)
    to extract features from the faces of the predicted and target images.
    The network expects 112x112 face crops, so faces are detected with MTCNN and resized to 112x112.
    The loss is the cosine distance, i.e. 1 minus the cosine similarity between the two feature vectors.
    Note that the backbone outputs are L2-normalized, so the dot product equals the cosine similarity.
    """
    def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs):
        super().__init__(**kwargs)
        assert pretrained_arcface_path is not None, "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from " \
            "https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing"
        self.mtcnn = MTCNN(device=self.accelerator.device)
        # Route the module's forward to `detect` so calling the prepared module returns detections.
        self.mtcnn.forward = self.mtcnn.detect
        self.facenet_input_size = 112
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(pretrained_arcface_path))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size))
        self.facenet.requires_grad_(False)
        self.facenet.eval()
        self.facenet.to(device=self.accelerator.device, dtype=self.dtype)
        self.face_pool.to(device=self.accelerator.device, dtype=self.dtype)
        self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC)
        self.reference_facial_points = np.array([[38.29459953, 51.69630051],
                                                 [72.53179932, 51.50139999],
                                                 [56.02519989, 71.73660278],
                                                 [41.54930115, 92.3655014],
                                                 [70.72990036, 92.20410156]])
        self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn)

        self.skip_not_found = skip_not_found

    def extract_feats(self, x: torch.Tensor):
        """
        Extract identity features from a face crop using the facenet model.
        """
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(
        self,
        predicted_pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        timesteps: torch.Tensor,
        **kwargs
    ):
        encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
        predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)

        predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn)
        with torch.no_grad():
            encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn)

        if self.skip_not_found:
            valid_indices = []
            for i in range(predicted_pixel_values.shape[0]):
                if i not in predicted_invalid_indices and i not in source_invalid_indices:
                    valid_indices.append(i)
        else:
            valid_indices = list(range(predicted_pixel_values.shape[0]))

        valid_indices = torch.tensor(valid_indices).to(device=predicted_pixel_values.device)

        if len(valid_indices) == 0:
            # No valid faces were found; return a zero loss that stays connected to the graph.
            loss = (predicted_pixel_values_face * 0.0).mean()
            if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
                self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
            return loss

        with torch.no_grad():
            pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices])

        predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices])
        # Embeddings are L2-normalized, so the batched dot product is the cosine similarity.
        loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats)

        if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
            self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
        return loss.mean()

    def visualize(
        self,
        predicted_pixel_values: torch.Tensor,
        encoder_pixel_values: torch.Tensor,
        predicted_pixel_values_face: torch.Tensor,
        encoder_pixel_values_face: torch.Tensor,
        timesteps: torch.Tensor,
        valid_indices: torch.Tensor,
        loss: torch.Tensor,
    ) -> None:
        small_predicted_pixel_values = rearrange(self.visualization_resize(predicted_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
        small_pixel_values = rearrange(self.visualization_resize(encoder_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
        small_predicted_pixel_values_face = rearrange(self.visualization_resize(predicted_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
        small_pixel_values_face = rearrange(self.visualization_resize(encoder_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()

        small_predicted_pixel_values = add_text_to_image(((small_predicted_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Images", add_below=False)
        small_pixel_values = add_text_to_image(((small_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Target Images", add_below=False)
        small_predicted_pixel_values_face = add_text_to_image(((small_predicted_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Faces", add_below=False)
        small_pixel_values_face = add_text_to_image(((small_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Target Faces", add_below=False)

        final_image = np.hstack([small_predicted_pixel_values, small_pixel_values, small_predicted_pixel_values_face, small_pixel_values_face])
        for tracker in self.accelerator.trackers:
            if tracker.name == 'wandb':
                tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")})


class ImageAugmentations(torch.nn.Module):

    def __init__(self, output_size, augmentations_number, p=0.7):
        super().__init__()
        self.output_size = output_size
        self.augmentations_number = augmentations_number

        self.augmentations = torch.nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),
            K.RandomPerspective(0.7, p=p),
        )

        self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size))

        self.device = None

    def forward(self, input):
        """Extends the input batch with augmentations.

        If the input consists of images [I1, I2], the extended augmented output
        will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2, ...].

        Args:
            input (torch.Tensor): input batch of shape [batch, C, H, W]
        Returns:
            updated batch: of shape [batch * augmentations_number, C, H, W]
        """
        resized_images = self.avg_pool(input)
        resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))

        batch_size = input.shape[0]

        # Keep the first replica un-augmented; augment the remaining replicas.
        non_augmented_batch = resized_images[:batch_size]
        augmented_batch = self.augmentations(resized_images[batch_size:])
        updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)

        return updated_batch
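

# --- Illustrative sketch (not part of the original training code) -----------
# Shows the shape contract of ImageAugmentations: a batch of B images becomes
# B * augmentations_number images at the pooled resolution, with the first B
# entries left un-augmented. The sizes below are arbitrary.
def _image_augmentations_shape_check() -> None:
    augmenter = ImageAugmentations(output_size=224, augmentations_number=4)
    batch = torch.rand(2, 3, 512, 512)
    out = augmenter(batch)
    assert out.shape == (2 * 4, 3, 224, 224)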


class CLIPLoss(Loss):
    def __init__(self, augmentations_number: int = 4, **kwargs):
        super().__init__(**kwargs)

        self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False)

        self.clip_model.device = None

        self.clip_model.eval().requires_grad_(False)

        # Map images from [-1, 1] to [0, 1], then apply CLIP's resize/crop and normalization
        # (skipping the PIL-specific transforms in clip_preprocess).
        self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +
                                             clip_preprocess.transforms[:2] +
                                             clip_preprocess.transforms[4:])

        self.clip_size = self.clip_model.visual.input_resolution

        self.clip_normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )

        self.image_augmentations = ImageAugmentations(output_size=self.clip_size,
                                                      augmentations_number=augmentations_number)

        self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations)

    def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        if not isinstance(decoder_prompts, list):
            decoder_prompts = [decoder_prompts]

        tokens = clip.tokenize(decoder_prompts).to(predicted_pixel_values.device)
        image = self.preprocess(predicted_pixel_values)

        logits_per_image, _ = self.clip_model(image, tokens)

        # Keep only the diagonal, i.e. each image's similarity with its own prompt.
        logits_per_image = torch.diagonal(logits_per_image)

        return (1. - logits_per_image / 100).mean()
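

# --- Illustrative sketch (not part of the original training code) -----------
# CLIP returns logits_per_image = logit_scale * cosine_similarity(image, text),
# and the learned logit_scale is clamped to at most 100 (and typically close to
# it), so dividing by 100 approximately recovers the cosine similarity. The
# numbers below are made up and only illustrate that relationship.
def _clip_loss_scaling_example() -> None:
    cosine_similarity = torch.tensor([0.31, 0.27])  # hypothetical image-text similarities
    logit_scale = 100.0                             # assumed saturated logit scale
    logits_per_image = logit_scale * cosine_similarity
    loss = (1.0 - logits_per_image / 100).mean()    # same form as CLIPLoss.forward
    assert torch.isclose(loss, (1.0 - cosine_similarity).mean())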


class DINOLoss(Loss):
    def __init__(
        self,
        dino_model,
        dino_preprocess,
        output_hidden_states: bool = False,
        center_momentum: float = 0.9,
        student_temp: float = 0.1,
        teacher_temp: float = 0.04,
        warmup_teacher_temp: float = 0.04,
        warmup_teacher_temp_epochs: int = 30,
        **kwargs):
        super().__init__(**kwargs)

        self.dino_model = dino_model
        self.output_hidden_states = output_hidden_states
        self.rescale_factor = dino_preprocess.rescale_factor

        # Map images from [-1, 1] to [0, 1], then resize/crop and normalize with the DINO statistics.
        self.preprocess = transforms.Compose(
            [
                transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
                transforms.Resize(size=256),
                transforms.CenterCrop(size=(224, 224)),
                transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std)
            ]
        )

        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype)

        self.dino_model = self.accelerator.prepare(self.dino_model)

    def forward(
        self,
        target: torch.Tensor,
        predict: torch.Tensor,
        weights: torch.Tensor = None,
        **kwargs) -> torch.Tensor:

        predict = self.preprocess(predict)
        target = self.preprocess(target)

        encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype)

        if self.output_hidden_states:
            # Would use the penultimate hidden state: self.dino_model(encoder_input, output_hidden_states=True).hidden_states[-2]
            raise ValueError("Output hidden states not supported for DINO loss.")
        image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state

        # The target features act as the teacher and the predicted features as the student.
        teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0)

        student_out = student_output.float() / self.student_temp

        temp = self.teacher_temp
        teacher_out = F.softmax((teacher_output.float() - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach()

        # Cross-entropy between the centered, sharpened teacher distribution and the student distribution.
        loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True)

        if weights is not None:
            loss = loss * weights
        return loss.mean()

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update the center used for the teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        # `Accelerator.reduce` returns the reduced tensor, so the result must be assigned.
        batch_center = self.accelerator.reduce(batch_center, reduction="sum")
        batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes)

        # Exponential moving average update of the center.
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
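

# --- Illustrative sketch (not part of the original training code) -----------
# Reproduces the distillation objective used in DINOLoss.forward on made-up
# features: the teacher distribution is centered and sharpened with a low
# temperature, and the loss is the cross-entropy against the student's
# log-softmax at a higher temperature. All tensors below are random stand-ins.
def _dino_loss_math_example() -> torch.Tensor:
    student_temp, teacher_temp = 0.1, 0.04
    center = torch.zeros(1, 257, 1024)
    teacher_output = torch.randn(2, 257, 1024)
    student_output = torch.randn(2, 257, 1024)

    teacher_out = F.softmax((teacher_output - center) / teacher_temp, dim=-1).detach()
    student_out = F.log_softmax(student_output / student_temp, dim=-1)
    loss = torch.sum(-teacher_out * student_out, dim=-1, keepdim=True)
    return loss.mean()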