philipp-zettl commited on
Commit
1005357
·
verified ·
1 Parent(s): 93014b5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -167,7 +167,7 @@ class MultiHeadClassification(nn.Module):
167
  Returns:
168
  None
169
  """
170
- model = torch.load(path)
171
  if head_name in self.heads:
172
  num_classes = model['weight'].shape[0]
173
  self.heads[head_name].load_state_dict(model)
@@ -209,7 +209,7 @@ class MultiHeadClassification(nn.Module):
209
  Args:
210
  path (str): Path to the file
211
  """
212
- self.load_state_dict(torch.load(path))
213
  self.to(self.torch_dtype).to(self.device)
214
 
215
  def save_backbone(self, path):
 
167
  Returns:
168
  None
169
  """
170
+ model = torch.load(path, map_location=self.device)
171
  if head_name in self.heads:
172
  num_classes = model['weight'].shape[0]
173
  self.heads[head_name].load_state_dict(model)
 
209
  Args:
210
  path (str): Path to the file
211
  """
212
+ self.load_state_dict(torch.load(path, map_location=self.device))
213
  self.to(self.torch_dtype).to(self.device)
214
 
215
  def save_backbone(self, path):