Spaces:
Runtime error
Runtime error
echen01
commited on
Commit
•
5b7158a
1
Parent(s):
2fec875
fix device
Browse files- criteria/id_loss.py +2 -2
- criteria/mask.py +2 -2
- training/coaches/base_coach.py +2 -1
criteria/id_loss.py
CHANGED
@@ -17,7 +17,7 @@ class IDLoss(nn.Module):
|
|
17 |
[4] https://github.com/eladrich/pixel2style2pixel
|
18 |
"""
|
19 |
|
20 |
-
def __init__(self, model_path, official=False):
|
21 |
"""
|
22 |
Arguments:
|
23 |
model_path (str): Path to IR-SE50 model.
|
@@ -32,7 +32,7 @@ class IDLoss(nn.Module):
|
|
32 |
input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
|
33 |
)
|
34 |
|
35 |
-
self.facenet.load_state_dict(torch.load(model_path))
|
36 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
37 |
self.facenet.eval()
|
38 |
|
|
|
17 |
[4] https://github.com/eladrich/pixel2style2pixel
|
18 |
"""
|
19 |
|
20 |
+
def __init__(self, model_path, official=False, device="cpu"):
|
21 |
"""
|
22 |
Arguments:
|
23 |
model_path (str): Path to IR-SE50 model.
|
|
|
32 |
input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
|
33 |
)
|
34 |
|
35 |
+
self.facenet.load_state_dict(torch.load(model_path, map_location=device))
|
36 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
37 |
self.facenet.eval()
|
38 |
|
criteria/mask.py
CHANGED
@@ -9,7 +9,7 @@ import numpy as np
|
|
9 |
|
10 |
|
11 |
class Mask(nn.Module):
|
12 |
-
def __init__(self):
|
13 |
"""
|
14 |
|
15 |
| Class | Number | Class | Number |
|
@@ -41,7 +41,7 @@ class Mask(nn.Module):
|
|
41 |
.requires_grad_(False)
|
42 |
)
|
43 |
|
44 |
-
ckpt = torch.load(paths_config.deeplab, map_location=
|
45 |
state_dict = {
|
46 |
k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
|
47 |
}
|
|
|
9 |
|
10 |
|
11 |
class Mask(nn.Module):
|
12 |
+
def __init__(self, device="cpu"):
|
13 |
"""
|
14 |
|
15 |
| Class | Number | Class | Number |
|
|
|
41 |
.requires_grad_(False)
|
42 |
)
|
43 |
|
44 |
+
ckpt = torch.load(paths_config.deeplab, map_location=device)
|
45 |
state_dict = {
|
46 |
k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
|
47 |
}
|
training/coaches/base_coach.py
CHANGED
@@ -51,13 +51,14 @@ class BaseCoach:
|
|
51 |
id_loss.IDLoss(
|
52 |
paths_config.ir_se50,
|
53 |
official=False,
|
|
|
54 |
)
|
55 |
.to(global_config.device)
|
56 |
.eval()
|
57 |
)
|
58 |
|
59 |
if hyperparameters.use_mask:
|
60 |
-
self.mask = mask.Mask()
|
61 |
|
62 |
self.restart_training()
|
63 |
|
|
|
51 |
id_loss.IDLoss(
|
52 |
paths_config.ir_se50,
|
53 |
official=False,
|
54 |
+
device=global_config.device
|
55 |
)
|
56 |
.to(global_config.device)
|
57 |
.eval()
|
58 |
)
|
59 |
|
60 |
if hyperparameters.use_mask:
|
61 |
+
self.mask = mask.Mask(device=global_config.device)
|
62 |
|
63 |
self.restart_training()
|
64 |
|