Sapphire-356 commited on
Commit
fb96f4f
·
1 Parent(s): aa34300

CPU Version: fix torch.load

Browse files
joints_detectors/Alphapose/SPPE/src/main_fast_inference.py CHANGED
@@ -27,7 +27,7 @@ class InferenNet(nn.Module):
27
  model = createModel()
28
  print('Loading pose model from {}'.format('joints_detectors/Alphapose/models/sppe/duc_se.pth'))
29
  sys.stdout.flush()
30
- model.load_state_dict(torch.load('joints_detectors/Alphapose/models/sppe/duc_se.pth'))
31
  model.eval()
32
  self.pyranet = model
33
 
@@ -54,7 +54,7 @@ class InferenNet_fast(nn.Module):
54
 
55
  model = createModel()
56
  print('Loading pose model from {}'.format('models/sppe/duc_se.pth'))
57
- model.load_state_dict(torch.load('models/sppe/duc_se.pth'))
58
  model.eval()
59
  self.pyranet = model
60
 
 
27
  model = createModel()
28
  print('Loading pose model from {}'.format('joints_detectors/Alphapose/models/sppe/duc_se.pth'))
29
  sys.stdout.flush()
30
+ model.load_state_dict(torch.load('joints_detectors/Alphapose/models/sppe/duc_se.pth', map_location=torch.device('cpu')))
31
  model.eval()
32
  self.pyranet = model
33
 
 
54
 
55
  model = createModel()
56
  print('Loading pose model from {}'.format('models/sppe/duc_se.pth'))
57
+ model.load_state_dict(torch.load('models/sppe/duc_se.pth', map_location=torch.device('cpu')))
58
  model.eval()
59
  self.pyranet = model
60
 
joints_detectors/Alphapose/SPPE/src/opt.py CHANGED
@@ -96,7 +96,7 @@ parser.add_argument('--port', dest='port',
96
 
97
  opt = parser.parse_args()
98
  if opt.Continue:
99
- opt = torch.load("../exp/{}/{}/option.pkl".format(opt.dataset, opt.expID))
100
  opt.Continue = True
101
  opt.nEpochs = 50
102
  print("--- Continue ---")
 
96
 
97
  opt = parser.parse_args()
98
  if opt.Continue:
99
+ opt = torch.load("../exp/{}/{}/option.pkl".format(opt.dataset, opt.expID), map_location=torch.device('cpu'))
100
  opt.Continue = True
101
  opt.nEpochs = 50
102
  print("--- Continue ---")
joints_detectors/Alphapose/train_sppe/src/train.py CHANGED
@@ -112,7 +112,7 @@ def main():
112
  m = createModel()
113
  if opt.loadModel:
114
  print('Loading Model from {}'.format(opt.loadModel))
115
- m.load_state_dict(torch.load(opt.loadModel))
116
  if not os.path.exists("../exp/{}/{}".format(opt.dataset, opt.expID)):
117
  try:
118
  os.mkdir("../exp/{}/{}".format(opt.dataset, opt.expID))
 
112
  m = createModel()
113
  if opt.loadModel:
114
  print('Loading Model from {}'.format(opt.loadModel))
115
+ m.load_state_dict(torch.load(opt.loadModel, map_location=torch.device('cpu')))
116
  if not os.path.exists("../exp/{}/{}".format(opt.dataset, opt.expID)):
117
  try:
118
  os.mkdir("../exp/{}/{}".format(opt.dataset, opt.expID))
tools/utils.py CHANGED
@@ -138,7 +138,7 @@ def videopose_model_load():
138
  # load trained model
139
  from common.model import TemporalModel
140
  chk_filename = main_path + '/checkpoint/pretrained_h36m_detectron_coco.bin'
141
- checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) # 把loc映射到storage
142
  model_pos = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3, 3, 3], causal=False, dropout=False, channels=1024, dense=False)
143
  model_pos = model_pos
144
  model_pos.load_state_dict(checkpoint['model_pos'])
 
138
  # load trained model
139
  from common.model import TemporalModel
140
  chk_filename = main_path + '/checkpoint/pretrained_h36m_detectron_coco.bin'
141
+ checkpoint = torch.load(chk_filename, map_location=torch.device('cpu')) # 把loc映射到storage
142
  model_pos = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3, 3, 3], causal=False, dropout=False, channels=1024, dense=False)
143
  model_pos = model_pos
144
  model_pos.load_state_dict(checkpoint['model_pos'])
videopose_PSTMO.py CHANGED
@@ -103,7 +103,7 @@ def main(args):
103
  model_dict = model['trans'].state_dict()
104
 
105
  no_refine_path = "checkpoint/PSTMOS_no_refine_48_5137_in_the_wild.pth"
106
- pre_dict = torch.load(no_refine_path)
107
  for key, value in pre_dict.items():
108
  name = key[7:]
109
  model_dict[name] = pre_dict[key]
 
103
  model_dict = model['trans'].state_dict()
104
 
105
  no_refine_path = "checkpoint/PSTMOS_no_refine_48_5137_in_the_wild.pth"
106
+ pre_dict = torch.load(no_refine_path, map_location=torch.device('cpu'))
107
  for key, value in pre_dict.items():
108
  name = key[7:]
109
  model_dict[name] = pre_dict[key]