Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision.models import vgg16 | |
from torchmetrics.functional import structural_similarity_index_measure | |
from facenet_pytorch import InceptionResnetV1 | |
from denormalize import denormalize_bin, denormalize_tr, denormalize_ar | |
class DF(nn.Module):
    """Weighted sum of MSE, perceptual, SSIM and identity-similarity losses.

    The perceptual term compares activations of the first 16 layers of
    ``vgg16().features``; the identity term compares ``InceptionResnetV1``
    (``vggface2`` weights) embeddings with a cosine-embedding loss. Both
    networks are frozen (parameters do not require gradients, modules are
    kept in eval mode) but gradients still flow through them to the inputs,
    so every term can drive a generator.

    Args:
        mse_weight: Multiplier for the pixel-wise MSE term.
        perceptual_weight: Multiplier for the VGG feature MSE term.
        ssim_weight: Multiplier for the ``1 - SSIM`` term.
        idsim_weight: Multiplier for the identity cosine-embedding term.
    """

    def __init__(
        self,
        mse_weight=0.25,
        perceptual_weight=0.25,
        ssim_weight=0.25,
        idsim_weight=0.25,
    ):
        super().__init__()
        self.mse_weight = mse_weight
        self.perceptual_weight = perceptual_weight
        self.ssim_weight = ssim_weight
        self.idsim_weight = idsim_weight

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE: the original referenced a bare, undefined ``device`` here;
        # the attribute defined just above is what was intended.
        self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
        self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()

        # Freeze both feature extractors. Freezing the parameters (instead of
        # wrapping the forward pass in ``torch.no_grad()``) keeps the weights
        # fixed while still letting gradients reach the input images.
        for frozen in (self.vgg, self.facenet):
            for param in frozen.parameters():
                param.requires_grad = False

        self.cosloss = nn.CosineEmbeddingLoss()

    def train(self, mode=True):
        """Set the training flag but keep the frozen extractors in eval mode."""
        super().train(mode)
        # A parent module's ``.train()`` would otherwise flip the frozen
        # networks' submodules back into training mode.
        self.vgg.eval()
        self.facenet.eval()
        return self

    def perceptual_loss(self, real, fake):
        """Return the MSE between ``self.vgg`` features of ``real`` and ``fake``."""
        real_features = self.vgg(real)
        fake_features = self.vgg(fake)
        return F.mse_loss(fake_features, real_features)

    def idsimilarity(self, real, fake):
        """Return the cosine-embedding loss between face embeddings.

        The target is +1 for every pair, so the loss is ``1 - cos_sim``
        averaged over the batch (0 when the embeddings are aligned).
        """
        input_embed = self.facenet(real)
        generated_embed = self.facenet(fake)
        # Target = 1 (maximize similarity); created on the embeddings' own
        # device/dtype so no extra transfers are needed.
        target = input_embed.new_ones(input_embed.size(0))
        return self.cosloss(input_embed, generated_embed, target)

    def forward(self, r, f):
        """Compute the combined loss for a batch.

        Args:
            r: Reference images, passed through ``denormalize_bin`` first.
            f: Generated images, passed through ``denormalize_bin`` first.

        Returns:
            A ``(total_loss, components)`` tuple: the weighted scalar tensor
            and a dict of the individual unweighted terms as Python floats.
        """
        # Per the original comment this maps [-1, 1] to [0, 1]
        # (denormalize_bin is defined outside this file -- TODO confirm).
        real = denormalize_bin(r)
        fake = denormalize_bin(f)

        mse_loss = F.mse_loss(real, fake)
        perceptual_loss = self.perceptual_loss(real, fake)
        idsim_loss = self.idsimilarity(real, fake)
        ssim_loss = 1 - structural_similarity_index_measure(fake, real)

        total_loss = (
            self.mse_weight * mse_loss
            + self.perceptual_weight * perceptual_loss
            + self.idsim_weight * idsim_loss
            + self.ssim_weight * ssim_loss
        )
        components = {
            "MSE Loss": mse_loss.item(),
            "Perceptual Loss": perceptual_loss.item(),
            "ID-SIM Loss": idsim_loss.item(),
            "SSIM Loss": ssim_loss.item(),
        }
        return total_loss, components