Spaces:
Running
Running
Sapphire-356
commited on
Commit
·
fb96f4f
1
Parent(s):
aa34300
CPU Version: fix torch.load
Browse files
joints_detectors/Alphapose/SPPE/src/main_fast_inference.py
CHANGED
@@ -27,7 +27,7 @@ class InferenNet(nn.Module):
|
|
27 |
model = createModel()
|
28 |
print('Loading pose model from {}'.format('joints_detectors/Alphapose/models/sppe/duc_se.pth'))
|
29 |
sys.stdout.flush()
|
30 |
-
model.load_state_dict(torch.load('joints_detectors/Alphapose/models/sppe/duc_se.pth'))
|
31 |
model.eval()
|
32 |
self.pyranet = model
|
33 |
|
@@ -54,7 +54,7 @@ class InferenNet_fast(nn.Module):
|
|
54 |
|
55 |
model = createModel()
|
56 |
print('Loading pose model from {}'.format('models/sppe/duc_se.pth'))
|
57 |
-
model.load_state_dict(torch.load('models/sppe/duc_se.pth'))
|
58 |
model.eval()
|
59 |
self.pyranet = model
|
60 |
|
|
|
27 |
model = createModel()
|
28 |
print('Loading pose model from {}'.format('joints_detectors/Alphapose/models/sppe/duc_se.pth'))
|
29 |
sys.stdout.flush()
|
30 |
+
model.load_state_dict(torch.load('joints_detectors/Alphapose/models/sppe/duc_se.pth', map_location=torch.device('cpu')))
|
31 |
model.eval()
|
32 |
self.pyranet = model
|
33 |
|
|
|
54 |
|
55 |
model = createModel()
|
56 |
print('Loading pose model from {}'.format('models/sppe/duc_se.pth'))
|
57 |
+
model.load_state_dict(torch.load('models/sppe/duc_se.pth', map_location=torch.device('cpu')))
|
58 |
model.eval()
|
59 |
self.pyranet = model
|
60 |
|
joints_detectors/Alphapose/SPPE/src/opt.py
CHANGED
@@ -96,7 +96,7 @@ parser.add_argument('--port', dest='port',
|
|
96 |
|
97 |
opt = parser.parse_args()
|
98 |
if opt.Continue:
|
99 |
-
opt = torch.load("../exp/{}/{}/option.pkl".format(opt.dataset, opt.expID))
|
100 |
opt.Continue = True
|
101 |
opt.nEpochs = 50
|
102 |
print("--- Continue ---")
|
|
|
96 |
|
97 |
opt = parser.parse_args()
|
98 |
if opt.Continue:
|
99 |
+
opt = torch.load("../exp/{}/{}/option.pkl".format(opt.dataset, opt.expID), map_location=torch.device('cpu'))
|
100 |
opt.Continue = True
|
101 |
opt.nEpochs = 50
|
102 |
print("--- Continue ---")
|
joints_detectors/Alphapose/train_sppe/src/train.py
CHANGED
@@ -112,7 +112,7 @@ def main():
|
|
112 |
m = createModel()
|
113 |
if opt.loadModel:
|
114 |
print('Loading Model from {}'.format(opt.loadModel))
|
115 |
-
m.load_state_dict(torch.load(opt.loadModel))
|
116 |
if not os.path.exists("../exp/{}/{}".format(opt.dataset, opt.expID)):
|
117 |
try:
|
118 |
os.mkdir("../exp/{}/{}".format(opt.dataset, opt.expID))
|
|
|
112 |
m = createModel()
|
113 |
if opt.loadModel:
|
114 |
print('Loading Model from {}'.format(opt.loadModel))
|
115 |
+
m.load_state_dict(torch.load(opt.loadModel, map_location=torch.device('cpu')))
|
116 |
if not os.path.exists("../exp/{}/{}".format(opt.dataset, opt.expID)):
|
117 |
try:
|
118 |
os.mkdir("../exp/{}/{}".format(opt.dataset, opt.expID))
|
tools/utils.py
CHANGED
@@ -138,7 +138,7 @@ def videopose_model_load():
|
|
138 |
# load trained model
|
139 |
from common.model import TemporalModel
|
140 |
chk_filename = main_path + '/checkpoint/pretrained_h36m_detectron_coco.bin'
|
141 |
-
checkpoint = torch.load(chk_filename, map_location=
|
142 |
model_pos = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3, 3, 3], causal=False, dropout=False, channels=1024, dense=False)
|
143 |
model_pos = model_pos
|
144 |
model_pos.load_state_dict(checkpoint['model_pos'])
|
|
|
138 |
# load trained model
|
139 |
from common.model import TemporalModel
|
140 |
chk_filename = main_path + '/checkpoint/pretrained_h36m_detectron_coco.bin'
|
141 |
+
checkpoint = torch.load(chk_filename, map_location=torch.device('cpu')) # 把loc映射到storage
|
142 |
model_pos = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3, 3, 3], causal=False, dropout=False, channels=1024, dense=False)
|
143 |
model_pos = model_pos
|
144 |
model_pos.load_state_dict(checkpoint['model_pos'])
|
videopose_PSTMO.py
CHANGED
@@ -103,7 +103,7 @@ def main(args):
|
|
103 |
model_dict = model['trans'].state_dict()
|
104 |
|
105 |
no_refine_path = "checkpoint/PSTMOS_no_refine_48_5137_in_the_wild.pth"
|
106 |
-
pre_dict = torch.load(no_refine_path)
|
107 |
for key, value in pre_dict.items():
|
108 |
name = key[7:]
|
109 |
model_dict[name] = pre_dict[key]
|
|
|
103 |
model_dict = model['trans'].state_dict()
|
104 |
|
105 |
no_refine_path = "checkpoint/PSTMOS_no_refine_48_5137_in_the_wild.pth"
|
106 |
+
pre_dict = torch.load(no_refine_path, map_location=torch.device('cpu'))
|
107 |
for key, value in pre_dict.items():
|
108 |
name = key[7:]
|
109 |
model_dict[name] = pre_dict[key]
|