席亚东
commited on
Commit
·
fc6a062
1
Parent(s):
05e5527
fix the bug in inferen.py
Browse files- inference.py +4 -4
inference.py
CHANGED
@@ -87,10 +87,10 @@ class Inference(object):
|
|
87 |
|
88 |
model_path = args.path
|
89 |
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
90 |
-
checkpoint["model"].update(model_path.replace("best.pt", "best_part_2.pt"))
|
91 |
-
checkpoint["model"].update(model_path.replace("best.pt", "best_part_3.pt"))
|
92 |
torch.save(checkpoint, model_path)
|
93 |
-
|
94 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
95 |
cfg_args = eval(str(state["cfg"]))["model"]
|
96 |
del cfg_args["_name"]
|
@@ -178,4 +178,4 @@ class Inference(object):
|
|
178 |
score = hypo['score'] / math.log(2) # convert to base 2
|
179 |
tmp_res.append([detok_hypo_str, score])
|
180 |
final_results.append(tmp_res)
|
181 |
-
return final_results
|
|
|
87 |
|
88 |
model_path = args.path
|
89 |
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
90 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_2.pt")))
|
91 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_3.pt")))
|
92 |
torch.save(checkpoint, model_path)
|
93 |
+
|
94 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
95 |
cfg_args = eval(str(state["cfg"]))["model"]
|
96 |
del cfg_args["_name"]
|
|
|
178 |
score = hypo['score'] / math.log(2) # convert to base 2
|
179 |
tmp_res.append([detok_hypo_str, score])
|
180 |
final_results.append(tmp_res)
|
181 |
+
return final_results
|