Haiyu Wu
commited on
Commit
·
365d0f1
1
Parent(s):
8acfda2
update
Browse files
app.py
CHANGED
@@ -57,16 +57,16 @@ def initialize_models():
|
|
57 |
rep_drop_prob=0.,
|
58 |
use_class_label=False)
|
59 |
generator = generator.to(device)
|
60 |
-
checkpoint = torch.load(generator_weights, map_location=
|
61 |
generator.load_state_dict(checkpoint['model_vec2face'])
|
62 |
generator.eval()
|
63 |
|
64 |
id_model = iresnet("100", fp16=True).to(device)
|
65 |
-
id_model.load_state_dict(torch.load(id_model_weights, map_location=
|
66 |
id_model.eval()
|
67 |
|
68 |
quality_model = iresnet("100", fp16=True).to(device)
|
69 |
-
quality_model.load_state_dict(torch.load(quality_model_weights, map_location=
|
70 |
quality_model.eval()
|
71 |
|
72 |
pose_model = SixDRepNet(backbone_name='RepVGG-B1g2',
|
|
|
57 |
rep_drop_prob=0.,
|
58 |
use_class_label=False)
|
59 |
generator = generator.to(device)
|
60 |
+
checkpoint = torch.load(generator_weights, map_location=device)
|
61 |
generator.load_state_dict(checkpoint['model_vec2face'])
|
62 |
generator.eval()
|
63 |
|
64 |
id_model = iresnet("100", fp16=True).to(device)
|
65 |
+
id_model.load_state_dict(torch.load(id_model_weights, map_location=device))
|
66 |
id_model.eval()
|
67 |
|
68 |
quality_model = iresnet("100", fp16=True).to(device)
|
69 |
+
quality_model.load_state_dict(torch.load(quality_model_weights, map_location=device))
|
70 |
quality_model.eval()
|
71 |
|
72 |
pose_model = SixDRepNet(backbone_name='RepVGG-B1g2',
|