|
import argparse |
|
|
|
|
|
def get_parser(parser=None): |
|
if parser is None: |
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("--data_root", type=str, required=False, default="") |
|
parser.add_argument("--grid_path", type=str, required=False, default="") |
|
|
|
parser.add_argument( |
|
"--lr_start", type=float, default=3 * 1e-4, help="Initial lr value" |
|
) |
|
parser.add_argument( |
|
"--max_epochs", type=int, required=False, default=1, help="max number of epochs" |
|
) |
|
|
|
parser.add_argument("--num_workers", type=int, default=0, required=False) |
|
parser.add_argument("--dropout", type=float, default=0.1, required=False) |
|
parser.add_argument("--n_batch", type=int, default=512, help="Batch size") |
|
parser.add_argument("--dataset_name", type=str, required=False, default="sol") |
|
parser.add_argument("--measure_name", type=str, required=False, default="measure") |
|
parser.add_argument("--checkpoints_folder", type=str, required=True) |
|
parser.add_argument("--model_path", type=str, default="./smi_ted/") |
|
parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt") |
|
parser.add_argument("--restart_filename", type=str, default="") |
|
parser.add_argument('--n_output', type=int, default=1) |
|
parser.add_argument("--save_every_epoch", type=int, default=0) |
|
parser.add_argument("--save_ckpt", type=int, default=1) |
|
parser.add_argument("--start_seed", type=int, default=0) |
|
parser.add_argument("--target_metric", type=str, default="rmse") |
|
parser.add_argument("--loss_fn", type=str, default="mae") |
|
|
|
return parser |
|
|
|
|
|
def parse_args(): |
|
parser = get_parser() |
|
args = parser.parse_args() |
|
return args |