vshirasuna's picture
Move code to 3dgrid_vqgan folder
a4c759f
import argparse
def get_parser(parser=None):
if parser is None:
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, required=False, default="")
parser.add_argument("--grid_path", type=str, required=False, default="")
parser.add_argument(
"--lr_start", type=float, default=3 * 1e-4, help="Initial lr value"
)
parser.add_argument(
"--max_epochs", type=int, required=False, default=1, help="max number of epochs"
)
parser.add_argument("--num_workers", type=int, default=0, required=False)
parser.add_argument("--dropout", type=float, default=0.1, required=False)
parser.add_argument("--n_batch", type=int, default=512, help="Batch size")
parser.add_argument("--dataset_name", type=str, required=False, default="sol")
parser.add_argument("--measure_name", type=str, required=False, default="measure")
parser.add_argument("--checkpoints_folder", type=str, required=True)
parser.add_argument("--model_path", type=str, default="./smi_ted/")
parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
parser.add_argument("--restart_filename", type=str, default="")
parser.add_argument('--n_output', type=int, default=1)
parser.add_argument("--save_every_epoch", type=int, default=0)
parser.add_argument("--save_ckpt", type=int, default=1)
parser.add_argument("--start_seed", type=int, default=0)
parser.add_argument("--target_metric", type=str, default="rmse")
parser.add_argument("--loss_fn", type=str, default="mae")
return parser
def parse_args():
parser = get_parser()
args = parser.parse_args()
return args