Update modules/frameworkeval.py
Browse files- modules/frameworkeval.py +5 -5
modules/frameworkeval.py
CHANGED
@@ -14,8 +14,8 @@ class DF(nn.Module):
|
|
14 |
self.ssim_weight = 0.25
|
15 |
self.idsim_weight = 0.25
|
16 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
-
self.vgg = vgg16(pretrained=True).features[:16].to(device).eval()
|
18 |
-
self.facenet = InceptionResnetV1(pretrained='vggface2').to(device).eval()
|
19 |
for param in self.facenet.parameters():
|
20 |
param.requires_grad = False # Freeze the model
|
21 |
self.cosloss = nn.CosineEmbeddingLoss()
|
@@ -29,10 +29,10 @@ class DF(nn.Module):
|
|
29 |
def idsimilarity(self, real, fake):
|
30 |
with torch.no_grad():
|
31 |
# Extract embeddings
|
32 |
-
input_embed = self.facenet(real).to(device)
|
33 |
-
generated_embed = self.facenet(fake).to(device)
|
34 |
# Compute cosine similarity loss
|
35 |
-
target = torch.ones(input_embed.size(0)).to(
|
36 |
return self.cosloss(input_embed, generated_embed, target)
|
37 |
|
38 |
def forward(self, r, f):
|
|
|
14 |
self.ssim_weight = 0.25
|
15 |
self.idsim_weight = 0.25
|
16 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
+
self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
|
18 |
+
self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()
|
19 |
for param in self.facenet.parameters():
|
20 |
param.requires_grad = False # Freeze the model
|
21 |
self.cosloss = nn.CosineEmbeddingLoss()
|
|
|
29 |
def idsimilarity(self, real, fake):
|
30 |
with torch.no_grad():
|
31 |
# Extract embeddings
|
32 |
+
input_embed = self.facenet(real).to(self.device)
|
33 |
+
generated_embed = self.facenet(fake).to(self.device)
|
34 |
# Compute cosine similarity loss
|
35 |
+
target = torch.ones(input_embed.size(0)).to(self.device) # Target = 1 (maximize similarity)
|
36 |
return self.cosloss(input_embed, generated_embed, target)
|
37 |
|
38 |
def forward(self, r, f):
|