File size: 1,692 Bytes
9123ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import argparse


def get_parser(parser=None):
    if parser is None:
        parser = argparse.ArgumentParser()

    parser.add_argument("--data_root", type=str, required=False, default="")
    parser.add_argument("--grid_path", type=str, required=False, default="")

    parser.add_argument(
        "--lr_start", type=float, default=3 * 1e-4, help="Initial lr value"
    )
    parser.add_argument(
        "--max_epochs", type=int, required=False, default=1, help="max number of epochs"
    )
    
    parser.add_argument("--num_workers", type=int, default=0, required=False)
    parser.add_argument("--dropout", type=float, default=0.1, required=False)
    parser.add_argument("--n_batch", type=int, default=512, help="Batch size")
    parser.add_argument("--dataset_name", type=str, required=False, default="sol")
    parser.add_argument("--measure_name", type=str, required=False, default="measure")
    parser.add_argument("--checkpoints_folder", type=str, required=True)
    parser.add_argument("--model_path", type=str, default="./smi_ted/")
    parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
    parser.add_argument("--restart_filename", type=str, default="")
    parser.add_argument('--n_output', type=int, default=1)
    parser.add_argument("--save_every_epoch", type=int, default=0)
    parser.add_argument("--save_ckpt", type=int, default=1)
    parser.add_argument("--start_seed", type=int, default=0)
    parser.add_argument("--target_metric", type=str, default="rmse")
    parser.add_argument("--loss_fn", type=str, default="mae")

    return parser


def parse_args():
    parser = get_parser()
    args = parser.parse_args()
    return args