Jannat24 commited on
Commit
2023edc
·
verified ·
1 Parent(s): ec0d70d

Update modules/frameworkeval.py

Browse files
Files changed (1) hide show
  1. modules/frameworkeval.py +5 -5
modules/frameworkeval.py CHANGED
@@ -14,8 +14,8 @@ class DF(nn.Module):
14
  self.ssim_weight = 0.25
15
  self.idsim_weight = 0.25
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- self.vgg = vgg16(pretrained=True).features[:16].to(device).eval()
18
- self.facenet = InceptionResnetV1(pretrained='vggface2').to(device).eval()
19
  for param in self.facenet.parameters():
20
  param.requires_grad = False # Freeze the model
21
  self.cosloss = nn.CosineEmbeddingLoss()
@@ -29,10 +29,10 @@ class DF(nn.Module):
29
  def idsimilarity(self, real, fake):
30
  with torch.no_grad():
31
  # Extract embeddings
32
- input_embed = self.facenet(real).to(device)
33
- generated_embed = self.facenet(fake).to(device)
34
  # Compute cosine similarity loss
35
- target = torch.ones(input_embed.size(0)).to(real.device) # Target = 1 (maximize similarity)
36
  return self.cosloss(input_embed, generated_embed, target)
37
 
38
  def forward(self, r, f):
 
14
  self.ssim_weight = 0.25
15
  self.idsim_weight = 0.25
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
18
+ self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()
19
  for param in self.facenet.parameters():
20
  param.requires_grad = False # Freeze the model
21
  self.cosloss = nn.CosineEmbeddingLoss()
 
29
  def idsimilarity(self, real, fake):
30
  with torch.no_grad():
31
  # Extract embeddings
32
+ input_embed = self.facenet(real).to(self.device)
33
+ generated_embed = self.facenet(fake).to(self.device)
34
  # Compute cosine similarity loss
35
+ target = torch.ones(input_embed.size(0)).to(self.device) # Target = 1 (maximize similarity)
36
  return self.cosloss(input_embed, generated_embed, target)
37
 
38
  def forward(self, r, f):