Jannat24's picture
Update modules/frameworkeval.py
2023edc verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16
from torchmetrics.functional import structural_similarity_index_measure
from facenet_pytorch import InceptionResnetV1
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
class DF(nn.Module):
def __init__(self):
super(DF, self).__init__()
self.mse_weight = 0.25
self.perceptual_weight = 0.25
self.ssim_weight = 0.25
self.idsim_weight = 0.25
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()
for param in self.facenet.parameters():
param.requires_grad = False # Freeze the model
self.cosloss = nn.CosineEmbeddingLoss()
def perceptual_loss(self, real, fake):
with torch.no_grad(): # VGG is frozen during training
real_features = self.vgg(real)
fake_features = self.vgg(fake)
return F.mse_loss(real_features, fake_features)
def idsimilarity(self, real, fake):
with torch.no_grad():
# Extract embeddings
input_embed = self.facenet(real).to(self.device)
generated_embed = self.facenet(fake).to(self.device)
# Compute cosine similarity loss
target = torch.ones(input_embed.size(0)).to(self.device) # Target = 1 (maximize similarity)
return self.cosloss(input_embed, generated_embed, target)
def forward(self, r, f):
real = denormalize_bin(r) #[-1,1] to [0,1]
fake = denormalize_bin(f)
mse_loss = F.mse_loss(real, fake)
perceptual_loss = self.perceptual_loss(real, fake)
idsim_loss = self.idsimilarity(real, fake)
ssim = structural_similarity_index_measure(fake, real)
ssim_loss = 1 - ssim
id_si = 1 - idsim_loss
total_loss = (self.mse_weight * mse_loss) + (self.perceptual_weight * perceptual_loss) + (self.idsim_weight * idsim_loss) + (self.ssim_weight * ssim_loss)
components = {
"MSE Loss": mse_loss.item(),
"Perceptual Loss": perceptual_loss.item(),
"ID-SIM Loss": idsim_loss.item(),
"SSIM Loss": ssim_loss.item()
}
return total_loss, components