File size: 747 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import argparse
"""Print a single field stored in a training checkpoint.

Usage:
    python <this script> CKPT FIELD_STRING

FIELD_STRING is a dot-separated path rooted at either ``args`` (an attribute
of the checkpoint's ``"Args"`` entry, exactly one level deep) or ``config``
(a chain of keys into the checkpoint's ``"Config"`` entry; ``config`` alone
prints the whole config).
"""


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt", help="eg. result/downstream/ExpName/states-100.ckpt")
    parser.add_argument("field_string", help="eg. config.runner.total_steps")
    return parser


def load_checkpoint(path):
    """Load the checkpoint at *path* onto the CPU.

    The checkpoint stores arbitrary Python objects (e.g. the ``Args``
    namespace), so full unpickling is required. torch >= 2.6 defaults to
    ``weights_only=True``, which would reject those objects, hence the
    explicit ``weights_only=False``.

    SECURITY: unpickling can execute arbitrary code -- only use this on
    checkpoints from a trusted source.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` parameter and rejects the
        # keyword; its default behaviour is already full unpickling.
        return torch.load(path, map_location="cpu")


def lookup_field(ckpt_args, ckpt_config, field_string):
    """Resolve a dotted *field_string* against a checkpoint's Args/Config.

    Args:
        ckpt_args: the object stored under the checkpoint's ``"Args"`` key;
            read with ``getattr``.
        ckpt_config: the object stored under the checkpoint's ``"Config"``
            key; traversed with ``[]`` indexing.
        field_string: ``"args.<name>"`` or ``"config[.<key>...]"``.

    Returns:
        The selected value (the whole config for a bare ``"config"``).

    Raises:
        ValueError: if the path is malformed or has an unknown root.
        AttributeError / KeyError: if the named field does not exist.
    """
    first_field, *remaining = field_string.split(".")
    if first_field == "args":
        # Args is read with getattr only, so exactly one level is supported.
        if len(remaining) != 1:
            raise ValueError(
                f"'args' paths must have exactly one field, got {field_string!r}"
            )
        return getattr(ckpt_args, remaining[0])
    if first_field == "config":
        target = ckpt_config
        for key in remaining:
            target = target[key]
        return target
    raise ValueError(
        f"field_string must start with 'args' or 'config', got {field_string!r}"
    )


def main():
    """Parse the command line, load the checkpoint and print the field."""
    parser = build_parser()
    args = parser.parse_args()
    ckpt = load_checkpoint(args.ckpt)
    try:
        value = lookup_field(ckpt["Args"], ckpt["Config"], args.field_string)
    except ValueError as err:
        # Report malformed paths as a usage error (exit status 2) instead of
        # silently printing nothing or relying on a strippable assert.
        parser.error(str(err))
    print(value)


if __name__ == "__main__":
    main()
|