Haiyu Wu commited on
Commit
365d0f1
·
1 Parent(s): 8acfda2
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -57,16 +57,16 @@ def initialize_models():
57
  rep_drop_prob=0.,
58
  use_class_label=False)
59
  generator = generator.to(device)
60
- checkpoint = torch.load(generator_weights, map_location='cpu')
61
  generator.load_state_dict(checkpoint['model_vec2face'])
62
  generator.eval()
63
 
64
  id_model = iresnet("100", fp16=True).to(device)
65
- id_model.load_state_dict(torch.load(id_model_weights, map_location='cpu'))
66
  id_model.eval()
67
 
68
  quality_model = iresnet("100", fp16=True).to(device)
69
- quality_model.load_state_dict(torch.load(quality_model_weights, map_location='cpu'))
70
  quality_model.eval()
71
 
72
  pose_model = SixDRepNet(backbone_name='RepVGG-B1g2',
 
57
  rep_drop_prob=0.,
58
  use_class_label=False)
59
  generator = generator.to(device)
60
+ checkpoint = torch.load(generator_weights, map_location=device)
61
  generator.load_state_dict(checkpoint['model_vec2face'])
62
  generator.eval()
63
 
64
  id_model = iresnet("100", fp16=True).to(device)
65
+ id_model.load_state_dict(torch.load(id_model_weights, map_location=device))
66
  id_model.eval()
67
 
68
  quality_model = iresnet("100", fp16=True).to(device)
69
+ quality_model.load_state_dict(torch.load(quality_model_weights, map_location=device))
70
  quality_model.eval()
71
 
72
  pose_model = SixDRepNet(backbone_name='RepVGG-B1g2',