# Vincentqyw — fix: roma (commit 8b973ee)
"""
Main file to launch training and testing experiments.
"""
import yaml
import os
import argparse
import numpy as np
import torch
from .config.project_config import Config as cfg
from .train import train_net
from .export import export_predictions, export_homograpy_adaptation
# Pytorch configurations
# Release any GPU memory cached by a previous run in this process.
torch.cuda.empty_cache()
# Let cuDNN benchmark kernels and pick the fastest for fixed-size inputs.
torch.backends.cudnn.benchmark = True
def load_config(config_path):
    """Read a yaml configuration file and return its contents as a dict."""
    # Fail early when the path does not point to an existing file.
    if not os.path.exists(config_path):
        raise ValueError("[Error] The provided config path is not valid.")
    # Parse and return the yaml content.
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
def update_config(path, model_cfg=None, dataset_cfg=None):
    """Update configuration file from the resume path."""
    # Normalize missing configs to empty dicts.
    if model_cfg is None:
        model_cfg = {}
    if dataset_cfg is None:
        dataset_cfg = {}

    model_cfg_file = os.path.join(path, "model_cfg.yaml")
    dataset_cfg_file = os.path.join(path, "dataset_cfg.yaml")

    # Load the configs saved next to the checkpoint; saved values take
    # precedence over the ones passed in.
    with open(model_cfg_file, "r") as f:
        model_cfg_saved = yaml.safe_load(f)
    model_cfg.update(model_cfg_saved)
    with open(dataset_cfg_file, "r") as f:
        dataset_cfg_saved = yaml.safe_load(f)
    dataset_cfg.update(dataset_cfg_saved)

    # Persist the merged configs when they differ from what was on disk.
    if model_cfg != model_cfg_saved:
        with open(model_cfg_file, "w") as f:
            yaml.dump(model_cfg, f)
    if dataset_cfg != dataset_cfg_saved:
        with open(dataset_cfg_file, "w") as f:
            yaml.dump(dataset_cfg, f)
    return model_cfg, dataset_cfg
def record_config(model_cfg, dataset_cfg, output_path):
    """Record dataset config to the log path."""
    # Write both configs next to the experiment outputs.
    for filename, cfg_dict in (
        ("model_cfg.yaml", model_cfg),
        ("dataset_cfg.yaml", dataset_cfg),
    ):
        with open(os.path.join(output_path, filename), "w") as f:
            yaml.safe_dump(cfg_dict, f)
def train(args, dataset_cfg, model_cfg, output_path):
    """Training function."""
    # Record the configs unless we are resuming into the same folder,
    # in which case the saved configs are already present there.
    resuming_in_place = args.resume and (
        os.path.realpath(output_path) == os.path.realpath(args.resume_path)
    )
    if not resuming_in_place:
        record_config(model_cfg, dataset_cfg, output_path)
    # Launch the training
    train_net(args, dataset_cfg, model_cfg, output_path)
def export(
    args,
    dataset_cfg,
    model_cfg,
    output_path,
    export_dataset_mode=None,
    device=torch.device("cuda"),
):
    """Export function."""
    # Homography adaptation is requested through the dataset config.
    use_homography_adaptation = (
        dataset_cfg.get("homography_adaptation") is not None
    )
    if use_homography_adaptation:
        print("[Info] Export predictions with homography adaptation.")
        export_homograpy_adaptation(
            args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
        )
    else:
        print("[Info] Export predictions normally.")
        export_predictions(
            args, dataset_cfg, model_cfg, output_path, export_dataset_mode
        )
def main(
    args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda")
):
    """Main function."""
    # Default output location for this experiment.
    output_path = os.path.join(cfg.EXP_PATH, args.exp_name)

    if args.mode == "train":
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        print("[Info] Training mode")
        print("\t Output path: %s" % output_path)
        train(args, dataset_cfg, model_cfg, output_path)
        return

    if args.mode == "export":
        # Exports go under the export data root instead of the experiment path.
        output_path = os.path.join(cfg.export_dataroot, args.exp_name)
        print("[Info] Export mode")
        print("\t Output path: %s" % output_path)
        export(
            args,
            dataset_cfg,
            model_cfg,
            output_path,
            export_dataset_mode,
            device=device,
        )
        return

    raise ValueError("[Error]: Unknown mode: " + args.mode)
def set_random_seed(seed):
    """Seed all random number generators used by the project.

    Seeds Python's `random`, NumPy, and PyTorch (CPU and, when available,
    all CUDA devices) so that experiments are reproducible.

    Args:
        seed: Integer seed value.
    """
    import random  # local import to keep file-level imports untouched

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Also seed every CUDA generator; no-op on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
if __name__ == "__main__":
    # Parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode", type=str, default="train", help="'train' or 'export'."
    )
    parser.add_argument(
        "--dataset_config", type=str, default=None, help="Path to the dataset config."
    )
    parser.add_argument(
        "--model_config", type=str, default=None, help="Path to the model config."
    )
    parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.")
    parser.add_argument(
        "--resume",
        action="store_true",
        default=False,
        help="Load a previously trained model.",
    )
    parser.add_argument(
        "--pretrained",
        action="store_true",
        default=False,
        help="Start training from a pre-trained model.",
    )
    parser.add_argument(
        "--resume_path", default=None, help="Path from which to resume training."
    )
    parser.add_argument(
        "--pretrained_path", default=None, help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--checkpoint_name", default=None, help="Name of the checkpoint to use."
    )
    parser.add_argument(
        "--export_dataset_mode", default=None, help="'train' or 'test'."
    )
    parser.add_argument(
        "--export_batch_size", default=4, type=int, help="Export batch size."
    )
    args = parser.parse_args()

    # Select the computation device (GPU when available, otherwise CPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fresh (non-resume) training requires both configs on the command line.
    if (
        ((args.dataset_config is None) or (args.model_config is None))
        and (not args.resume)
        and (args.mode == "train")
    ):
        raise ValueError(
            "[Error] The dataset config and model config should be given in non-resume mode"
        )
    # If resume, check if the resume path has been given
    if args.resume and (args.resume_path is None):
        raise ValueError("[Error] Missing resume path.")

    # [Training] Load the config file.
    if args.mode == "train" and (not args.resume):
        # Check the pretrained checkpoint exists before starting.
        if args.pretrained:
            # BUGFIX: the checkpoint lives under --pretrained_path here; the
            # previous `checkpoint_folder = args.resume_path` assignment was a
            # dead store of None (resume_path is unset in non-resume mode).
            checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name)
            if not os.path.exists(checkpoint_path):
                raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)
        dataset_cfg = load_config(args.dataset_config)
        model_cfg = load_config(args.model_config)
    # [resume Training, Test, Export] Load the config file.
    elif (args.mode == "train" and args.resume) or (args.mode == "export"):
        # Check checkpoint path exists
        checkpoint_folder = args.resume_path
        checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name)
        if not os.path.exists(checkpoint_path):
            raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)

        # Load model_cfg from checkpoint folder if not provided
        if args.model_config is None:
            print("[Info] No model config provided. Loading from checkpoint folder.")
            model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml")
            if not os.path.exists(model_cfg_path):
                raise ValueError("[Error] Missing model config in checkpoint path.")
            model_cfg = load_config(model_cfg_path)
        else:
            model_cfg = load_config(args.model_config)

        # Load dataset_cfg from checkpoint folder if not provided
        if args.dataset_config is None:
            print("[Info] No dataset config provided. Loading from checkpoint folder.")
            dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml")
            if not os.path.exists(dataset_cfg_path):
                raise ValueError("[Error] Missing dataset config in checkpoint path.")
            dataset_cfg = load_config(dataset_cfg_path)
        else:
            dataset_cfg = load_config(args.dataset_config)

        # Check the --export_dataset_mode flag
        if (args.mode == "export") and (args.export_dataset_mode is None):
            raise ValueError("[Error] Empty --export_dataset_mode flag.")
    else:
        raise ValueError("[Error] Unknown mode: " + args.mode)

    # Set the random seed
    seed = dataset_cfg.get("random_seed", 0)
    set_random_seed(seed)

    main(
        args,
        dataset_cfg,
        model_cfg,
        export_dataset_mode=args.export_dataset_mode,
        device=device,
    )