Haiyu Wu commited on
Commit
e1eebbb
·
1 Parent(s): c646cb8
pixel_generator/vec2face/model_vec2face.py CHANGED
@@ -303,7 +303,7 @@ class MaskedGenerativeEncoderViT(nn.Module):
303
  id_loss = torch.mean(1 - torch.cosine_similarity(out_feature, rep))
304
  else:
305
  distance = 1 - torch.cosine_similarity(out_feature, class_rep)
306
- id_loss = torch.mean(torch.where(distance > 0.1, distance, torch.zeros_like(distance)))
307
  quality = quality_model(image)
308
  norm = torch.norm(quality, 2, 1, True)
309
  q_loss = torch.where(norm < q_target, q_target - norm, torch.zeros_like(norm))
@@ -318,7 +318,7 @@ class MaskedGenerativeEncoderViT(nn.Module):
318
  yaw_loss = torch.abs(pose - torch.abs(pose_info[:, 1].clip(min=-90, max=90)))
319
  pose_loss = torch.mean(yaw_loss)
320
  q_loss = torch.mean(q_loss)
321
- if pose_loss > 5 or id_loss > 0.1 or q_loss > 1:
322
  i -= 1
323
  loss = id_loss * 100 + q_loss + pose_loss
324
  optm.zero_grad()
 
303
  id_loss = torch.mean(1 - torch.cosine_similarity(out_feature, rep))
304
  else:
305
  distance = 1 - torch.cosine_similarity(out_feature, class_rep)
306
+ id_loss = torch.mean(torch.where(distance > 0., distance, torch.zeros_like(distance)))
307
  quality = quality_model(image)
308
  norm = torch.norm(quality, 2, 1, True)
309
  q_loss = torch.where(norm < q_target, q_target - norm, torch.zeros_like(norm))
 
318
  yaw_loss = torch.abs(pose - torch.abs(pose_info[:, 1].clip(min=-90, max=90)))
319
  pose_loss = torch.mean(yaw_loss)
320
  q_loss = torch.mean(q_loss)
321
+ if pose_loss > 5 or id_loss > 0.3 or q_loss > 1:
322
  i -= 1
323
  loss = id_loss * 100 + q_loss + pose_loss
324
  optm.zero_grad()