echen01 commited on
Commit
5b7158a
1 Parent(s): 2fec875

fix device

Browse files
criteria/id_loss.py CHANGED
@@ -17,7 +17,7 @@ class IDLoss(nn.Module):
17
  [4] https://github.com/eladrich/pixel2style2pixel
18
  """
19
 
20
- def __init__(self, model_path, official=False):
21
  """
22
  Arguments:
23
  model_path (str): Path to IR-SE50 model.
@@ -32,7 +32,7 @@ class IDLoss(nn.Module):
32
  input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
33
  )
34
 
35
- self.facenet.load_state_dict(torch.load(model_path))
36
  self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
37
  self.facenet.eval()
38
 
 
17
  [4] https://github.com/eladrich/pixel2style2pixel
18
  """
19
 
20
+ def __init__(self, model_path, official=False, device="cpu"):
21
  """
22
  Arguments:
23
  model_path (str): Path to IR-SE50 model.
 
32
  input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
33
  )
34
 
35
+ self.facenet.load_state_dict(torch.load(model_path, map_location=device))
36
  self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
37
  self.facenet.eval()
38
 
criteria/mask.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
9
 
10
 
11
  class Mask(nn.Module):
12
- def __init__(self):
13
  """
14
 
15
  | Class | Number | Class | Number |
@@ -41,7 +41,7 @@ class Mask(nn.Module):
41
  .requires_grad_(False)
42
  )
43
 
44
- ckpt = torch.load(paths_config.deeplab, map_location=global_config.device)
45
  state_dict = {
46
  k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
47
  }
 
9
 
10
 
11
  class Mask(nn.Module):
12
+ def __init__(self, device="cpu"):
13
  """
14
 
15
  | Class | Number | Class | Number |
 
41
  .requires_grad_(False)
42
  )
43
 
44
+ ckpt = torch.load(paths_config.deeplab, map_location=device)
45
  state_dict = {
46
  k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
47
  }
training/coaches/base_coach.py CHANGED
@@ -51,13 +51,14 @@ class BaseCoach:
51
  id_loss.IDLoss(
52
  paths_config.ir_se50,
53
  official=False,
 
54
  )
55
  .to(global_config.device)
56
  .eval()
57
  )
58
 
59
  if hyperparameters.use_mask:
60
- self.mask = mask.Mask()
61
 
62
  self.restart_training()
63
 
 
51
  id_loss.IDLoss(
52
  paths_config.ir_se50,
53
  official=False,
54
+ device=global_config.device
55
  )
56
  .to(global_config.device)
57
  .eval()
58
  )
59
 
60
  if hyperparameters.use_mask:
61
+ self.mask = mask.Mask(device=global_config.device)
62
 
63
  self.restart_training()
64