glenn-jocher commited on
Commit
4b0e374
·
unverified ·
2 Parent(s): 4dd75b3 4f6c0cf

Merge pull request #59 from Lijun-Yu/master

Browse files
Files changed (1) hide show
  1. hubconf.py +6 -2
hubconf.py CHANGED
@@ -6,6 +6,9 @@ Usage:
6
  """
7
 
8
  dependencies = ['torch', 'yaml']
 
 
 
9
  import torch
10
 
11
  from models.yolo import Model
@@ -24,11 +27,12 @@ def create(name, pretrained, channels, classes):
24
  Returns:
25
  pytorch model
26
  """
27
- model = Model('models/%s.yaml' % name, channels, classes)
 
28
  if pretrained:
29
  ckpt = '%s.pt' % name # checkpoint filename
30
  google_utils.attempt_download(ckpt) # download if not found locally
31
- state_dict = torch.load(ckpt)['model'].state_dict()
32
  state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()} # filter
33
  model.load_state_dict(state_dict, strict=False) # load
34
  return model
 
6
  """
7
 
8
  dependencies = ['torch', 'yaml']
9
+
10
+ import os
11
+
12
  import torch
13
 
14
  from models.yolo import Model
 
27
  Returns:
28
  pytorch model
29
  """
30
+ config = os.path.join(os.path.dirname(__file__), 'models', '%s.yaml' % name) # model.yaml path
31
+ model = Model(config, channels, classes)
32
  if pretrained:
33
  ckpt = '%s.pt' % name # checkpoint filename
34
  google_utils.attempt_download(ckpt) # download if not found locally
35
+ state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].state_dict()
36
  state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()} # filter
37
  model.load_state_dict(state_dict, strict=False) # load
38
  return model