# Haiyu Wu
# vec2face demo
# 918e8a0
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import iresnet
from lpips.lpips import LPIPS
from pytorch_msssim import SSIM
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return `weight` once `global_step` reaches `threshold`, else `value`.

    Used to keep a loss term pinned at `value` (typically 0) during warm-up
    and switch it on afterwards.
    """
    return weight if global_step >= threshold else value
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: average of the real and fake margin terms."""
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def mse_d_loss(logits_real, logits_fake):
    """Least-squares (LSGAN) discriminator loss: real target 1, fake target 0."""
    real_term = ((logits_real - 1.) ** 2).mean()
    fake_term = (logits_fake ** 2).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (vanilla) GAN discriminator loss via softplus."""
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def create_fr_model(model_path, depth="100"):
    """Build an iresnet face-recognition model and load pretrained weights.

    Args:
        model_path: path to a saved state-dict checkpoint.
        depth: iresnet depth variant passed to the factory (default "100").

    Returns:
        The iresnet model with the checkpoint weights loaded.
    """
    model = iresnet(depth)
    # map_location="cpu" lets CUDA-saved checkpoints load on CPU-only hosts;
    # the caller is responsible for moving the model to its device.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    return model
def downscale(img: torch.Tensor) -> torch.Tensor:
    """Bicubically downsample a square image batch to 1/8 of its side length.

    Args:
        img: image batch whose last dimension is the (square) spatial size.
             Assumes NCHW layout with H == W — TODO confirm with callers.

    Returns:
        The resized batch with spatial size (side // 8, side // 8).
    """
    # Note: the scale is 1/8, not 1/2 — the old local name `half_size`
    # was misleading.
    target_size = img.shape[-1] // 8
    return F.interpolate(img, size=(target_size, target_size),
                         mode='bicubic', align_corners=False)
class VQLPIPSWithDiscriminator(nn.Module):
    """Combined generator/discriminator loss for a VQ autoencoder.

    The generator (autoencoder) loss sums pixel MSE, LPIPS perceptual loss,
    SSIM, a token-prediction loss, an identity-feature loss from a frozen
    face-recognition model, and an embedding loss, plus an adaptively
    weighted adversarial term. The discriminator loss is computed on
    detached images with the configured GAN loss.

    Args:
        disc_start: epoch threshold before which the adversarial terms are
            zeroed via `adopt_weight`.
        disc_factor: global multiplier on the adversarial terms.
        disc_weight: multiplier applied to the adaptive generator weight.
        disc_conditional: stored but not used in this class.
        disc_loss: one of "hinge", "vanilla", "mse".
        id_loss: "mse" or "cosine"; selects `self.feature_loss`.
            NOTE(review): `forward` computes cosine similarity directly and
            never calls `self.feature_loss` — confirm this is intended.
        fr_model: checkpoint path for the face-recognition model.
    """

    def __init__(self, disc_start=1000, disc_factor=1.0, disc_weight=1.0,
                 disc_conditional=False, disc_loss="mse", id_loss="mse",
                 fr_model="./models/arcface-r100-glint360k.pth"):
        super().__init__()
        # "smooth" used to be accepted here but was never implemented (the
        # dispatch below raised ValueError for it), so reject it up front.
        assert disc_loss in ["hinge", "vanilla", "mse"]
        self.loss_name = disc_loss
        # Frozen LPIPS network for perceptual distance.
        self.perceptual_loss = LPIPS().eval()
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "mse":
            self.disc_loss = mse_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        # Frozen face-recognition model used for the identity loss.
        self.fr_model = create_fr_model(fr_model).eval()
        if id_loss == "mse":
            self.feature_loss = nn.MSELoss()
        elif id_loss == "cosine":
            self.feature_loss = nn.CosineSimilarity()
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        # SSIM expects inputs in [0, 1]; forward rescales from [-1, 1].
        self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the adversarial term against the reconstruction term.

        Scales the generator GAN loss so its gradient norm at the decoder's
        last layer matches the reconstruction loss's gradient norm, as in
        VQGAN. The result is detached, clamped to [0, 1e4], and multiplied
        by `self.discriminator_weight`.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # Fallback: expects the owning module to expose self.last_layer.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
        # 1e-4 guards against division by a vanishing generator gradient.
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, im_features, gt_indices, logits, gt_img, image, discriminator, emb_loss,
                epoch, last_layer=None, cond=None, mask=None):
        """Compute generator and discriminator losses for one batch.

        Args:
            im_features: target identity features for the ground-truth faces.
            gt_indices: target token values; compared against `logits`.
            logits: predicted token values (position 0 is skipped).
            gt_img: ground-truth images, presumably in [-1, 1] — TODO confirm.
            image: reconstructed images from the generator.
            discriminator: the discriminator network.
            emb_loss: embedding/codebook loss, added to the NLL term.
            epoch: current epoch, gates the identity and adversarial terms.
            last_layer: decoder's last layer for adaptive-weight gradients.
            cond: unused here (kept for interface compatibility).
            mask: optional token mask; switches the token loss to masked L1.

        Returns:
            (ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss,
            feature_loss); rec_loss is the unreduced per-pixel MSE tensor.
        """
        rec_loss = (image - gt_img) ** 2
        if epoch >= 0:
            # Identity loss: cosine distance between FR embeddings of the
            # generated image and the target features.
            gen_feature = self.fr_model(image)
            feature_loss = torch.mean(1 - torch.cosine_similarity(im_features, gen_feature))
        else:
            feature_loss = 0
        p_loss = self.perceptual_loss(image, gt_img) * 2
        # SSIM in fp32: pytorch_msssim is unstable under autocast; inputs are
        # mapped from [-1, 1] to [0, 1] to match data_range=1.
        with torch.cuda.amp.autocast(enabled=False):
            ssim_loss = 1 - self.ssim_loss((image.float() + 1) / 2, (gt_img + 1) / 2)
        # Generator pass keeps the graph; the two detached passes feed the
        # discriminator update without backproping into the generator.
        logits_fake = discriminator(image)
        logits_real_d = discriminator(gt_img.detach())
        logits_fake_d = discriminator(image.detach())
        if mask is None:
            # Unmasked: plain MSE over all token positions except the first.
            token_loss = (logits[:, 1:, :] - gt_indices[:, 1:, :])
            token_loss = torch.mean(token_loss ** 2)
        else:
            # Masked: L1 restricted to masked positions, normalized by count.
            token_loss = torch.abs((logits[:, 1:, :] - gt_indices[:, 1:, :])) * mask[:, 1:, None]
            token_loss = token_loss.sum() / mask[:, 1:].sum()
            # token_loss = 0
        nll_loss = torch.mean(rec_loss + p_loss) + \
                   ssim_loss + \
                   token_loss + feature_loss + emb_loss
        # generator update: non-saturating adversarial term, adaptively
        # weighted and gated until discriminator_iter_start.
        g_loss = -torch.mean(logits_fake)
        d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
        disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
        ae_loss = nll_loss + d_weight * disc_factor * g_loss
        # second pass for discriminator update, gated by the same schedule
        disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
        d_loss = disc_factor * self.disc_loss(logits_real_d, logits_fake_d)
        return ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss, feature_loss