|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision.models import vgg16 |
|
from torchmetrics.functional import structural_similarity_index_measure |
|
from facenet_pytorch import InceptionResnetV1 |
|
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar |
|
|
|
class DF(nn.Module):
    """Composite image loss: pixel MSE + perceptual + identity + SSIM terms.

    The four terms are combined with fixed weights (0.25 each). The
    perceptual term compares activations of the first 16 layers of a
    pretrained VGG16 ``features`` stack; the identity term compares
    embeddings from a pretrained ``InceptionResnetV1`` ('vggface2').

    Both feature extractors are frozen and kept in eval mode, but gradients
    are allowed to flow *through* them to the generated image so that every
    term contributes to optimisation.
    """

    def __init__(self):
        super().__init__()

        # Relative weight of each term in the total loss.
        self.mse_weight = 0.25
        self.perceptual_weight = 0.25
        self.ssim_weight = 0.25
        self.idsim_weight = 0.25

        # Device the pretrained networks are placed on at construction time.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True``; kept as-is for compatibility —
        # confirm against the pinned torchvision version.
        self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
        self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()

        # Freeze both extractors: they are fixed feature functions, so their
        # weights must neither accumulate gradients nor be picked up as
        # trainable by an optimizer built from ``self.parameters()``.
        for network in (self.vgg, self.facenet):
            for param in network.parameters():
                param.requires_grad = False

        self.cosloss = nn.CosineEmbeddingLoss()

    def train(self, mode=True):
        """Set training mode while keeping the frozen extractors in eval mode.

        ``nn.Module.train`` propagates to every child module; without this
        override a caller's ``.train()`` would switch the pretrained
        networks out of eval mode as well.
        """
        super().train(mode)
        self.vgg.eval()
        self.facenet.eval()
        return self

    def perceptual_loss(self, real, fake):
        """Return the MSE between VGG activations of ``real`` and ``fake``.

        Args:
            real: Reference image batch; treated as a constant target.
            fake: Generated image batch; gradients flow back to it.

        Returns:
            Scalar tensor with the mean squared feature difference.
        """
        # The reference features are a fixed target, so no graph is needed.
        with torch.no_grad():
            real_features = self.vgg(real)
        # Computed with autograd enabled so the loss can train the generator.
        fake_features = self.vgg(fake)
        return F.mse_loss(fake_features, real_features)

    def idsimilarity(self, real, fake):
        """Return the cosine-embedding loss between face embeddings.

        Args:
            real: Reference image batch; treated as a constant target.
            fake: Generated image batch; gradients flow back to it.

        Returns:
            Scalar tensor from ``nn.CosineEmbeddingLoss`` with an all-ones
            target, i.e. the embeddings are pushed to be similar.
        """
        with torch.no_grad():
            input_embed = self.facenet(real)
        generated_embed = self.facenet(fake)

        # Build the target on the embeddings' own device so this keeps
        # working if the module or its inputs were moved after construction.
        target = torch.ones(input_embed.size(0), device=input_embed.device)
        return self.cosloss(input_embed, generated_embed, target)

    def forward(self, r, f):
        """Compute the weighted total loss and its per-term values.

        Args:
            r: Reference image batch, passed through ``denormalize_bin``.
            f: Generated image batch, passed through ``denormalize_bin``.
               (Assumes both are normalised image tensors accepted by that
               helper — TODO confirm against ``modules.denormalize``.)

        Returns:
            Tuple ``(total_loss, components)`` where ``total_loss`` is a
            tensor and ``components`` maps each term's name to its Python
            float value.
        """
        real = denormalize_bin(r)
        fake = denormalize_bin(f)

        mse_loss = F.mse_loss(fake, real)
        perceptual_loss = self.perceptual_loss(real, fake)
        idsim_loss = self.idsimilarity(real, fake)
        # SSIM is a similarity, so turn it into a loss to minimise.
        ssim_loss = 1 - structural_similarity_index_measure(fake, real)

        total_loss = (
            self.mse_weight * mse_loss
            + self.perceptual_weight * perceptual_loss
            + self.idsim_weight * idsim_loss
            + self.ssim_weight * ssim_loss
        )

        components = {
            "MSE Loss": mse_loss.item(),
            "Perceptual Loss": perceptual_loss.item(),
            "ID-SIM Loss": idsim_loss.item(),
            "SSIM Loss": ssim_loss.item(),
        }

        return total_loss, components
|
|