Haiyu Wu
commited on
Commit
·
283ee70
1
Parent(s):
89e5981
update
Browse files
pixel_generator/vec2face/model_vec2face.py
CHANGED
@@ -303,7 +303,7 @@ class MaskedGenerativeEncoderViT(nn.Module):
|
|
303 |
id_loss = torch.mean(1 - torch.cosine_similarity(out_feature, rep))
|
304 |
else:
|
305 |
distance = 1 - torch.cosine_similarity(out_feature, class_rep)
|
306 |
-
id_loss = torch.mean(torch.where(distance > 0.
|
307 |
quality = quality_model(image)
|
308 |
norm = torch.norm(quality, 2, 1, True)
|
309 |
q_loss = torch.where(norm < q_target, q_target - norm, torch.zeros_like(norm))
|
|
|
303 |
id_loss = torch.mean(1 - torch.cosine_similarity(out_feature, rep))
|
304 |
else:
|
305 |
distance = 1 - torch.cosine_similarity(out_feature, class_rep)
|
306 |
+
id_loss = torch.mean(torch.where(distance > 0.1, distance, torch.zeros_like(distance)))
|
307 |
quality = quality_model(image)
|
308 |
norm = torch.norm(quality, 2, 1, True)
|
309 |
q_loss = torch.where(norm < q_target, q_target - norm, torch.zeros_like(norm))
|