"""Generate MMSegmentation training configs for several segmentation models.

Each ``create_*_config`` function loads a model config and a dataset config,
merges the dataset config into the model config, applies model-specific and
shared training overrides, and dumps the result to ``<save_dir>/<name>.py``,
where ``<name>`` is ``<dataset-prefix>_<model-dir>``.
"""
import os
import argparse

from mmengine import Config

# Shared training-loop settings applied to every generated config.
LOG_INTERVAL = 100
CHECKPOINT_INTERVAL = 2500
MAX_KEEP_CKPTS = 1
SAVE_BEST_METRIC = 'mIoU'
RANDOM_SEED = 0


def _load_merged_config(model_config_path, dataset_config_path):
    """Load the model config and merge the dataset config on top of it."""
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)
    return cfg


def _apply_training_settings(cfg, batch_size, max_iters, val_interval):
    """Apply the batch size, schedule, hook and seed overrides shared by all models."""
    cfg.train_dataloader.batch_size = batch_size
    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = LOG_INTERVAL
    cfg.default_hooks.checkpoint.interval = CHECKPOINT_INTERVAL
    cfg.default_hooks.checkpoint.max_keep_ckpts = MAX_KEEP_CKPTS
    cfg.default_hooks.checkpoint.save_best = SAVE_BEST_METRIC
    cfg['randomness'] = dict(seed=RANDOM_SEED)


def _experiment_name(model_config_path, dataset_config_path):
    """Return ``<dataset-prefix>_<model-dir>`` for naming work/config outputs.

    The dataset prefix is the dataset config's file name up to the first
    underscore; the model part is the name of the directory that directly
    contains the model config file.
    """
    dataset_prefix = os.path.basename(dataset_config_path).split('_')[0]
    # abspath normalises "./", mixed separators and bare file names, so the
    # parent directory name is found regardless of how the path was written.
    model_dir = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    return f"{dataset_prefix}_{model_dir}"


def _finalize_and_dump(cfg, model_config_path, dataset_config_path, work_dir, save_dir):
    """Set output directories on ``cfg``, write it to disk and return the file path."""
    cfg.save_dir = save_dir
    name = _experiment_name(model_config_path, dataset_config_path)
    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)
    # The config file is written into save_dir, so it must exist first;
    # an empty save_dir means "current directory" and needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
    return save_config_file


def create_deeplabv3plus_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a DeepLabV3+ training config.

    Sets a 512x512 crop, plain BN on backbone/decode/auxiliary heads and the
    class count on both heads, then applies the shared training settings.

    Args:
        model_config_path: Path to the base model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory the generated config file is written to.
        batch_size: Training batch size.
        max_iters: Total training iterations.
        val_interval: Iterations between validation runs.

    Returns:
        Path of the written config file.
    """
    cfg = _load_merged_config(model_config_path, dataset_config_path)

    cfg.crop_size = (512, 512)
    cfg.model.data_preprocessor.size = cfg.crop_size

    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

    cfg.model.decode_head.num_classes = num_class
    cfg.model.auxiliary_head.num_classes = num_class

    _apply_training_settings(cfg, batch_size, max_iters, val_interval)
    return _finalize_and_dump(cfg, model_config_path, dataset_config_path, work_dir, save_dir)


def create_knet_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a K-Net training config.

    Uses the crop size already defined by the merged configs, sets the class
    count on the kernel generation head, any kernel update heads and the
    auxiliary head, then applies the shared training settings.

    Args:
        model_config_path: Path to the base model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory the generated config file is written to.
        batch_size: Training batch size.
        max_iters: Total training iterations.
        val_interval: Iterations between validation runs.

    Returns:
        Path of the written config file.
    """
    cfg = _load_merged_config(model_config_path, dataset_config_path)

    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    # Relies on crop_size being defined by the model or dataset config.
    cfg.model.data_preprocessor.size = cfg.crop_size

    decode_head = cfg.model.decode_head
    decode_head.kernel_generate_head.num_classes = num_class
    # Keep any per-stage update heads in sync so no stage is left with the
    # base config's class count (only touched when the key is present).
    for update_head in decode_head.get('kernel_update_head', None) or []:
        if 'num_classes' in update_head:
            update_head['num_classes'] = num_class
    cfg.model.auxiliary_head.num_classes = num_class

    _apply_training_settings(cfg, batch_size, max_iters, val_interval)
    return _finalize_and_dump(cfg, model_config_path, dataset_config_path, work_dir, save_dir)


def create_mask2former_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a Mask2Former training config.

    Sets a 512x512 crop, the class count on the decode head and a per-class
    weight list for its classification loss, then applies the shared
    training settings.

    Args:
        model_config_path: Path to the base model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory the generated config file is written to.
        batch_size: Training batch size.
        max_iters: Total training iterations.
        val_interval: Iterations between validation runs.

    Returns:
        Path of the written config file.
    """
    cfg = _load_merged_config(model_config_path, dataset_config_path)

    cfg.crop_size = (512, 512)
    cfg.model.data_preprocessor.size = cfg.crop_size

    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    cfg.model.decode_head.num_classes = num_class
    # num_class weights of 1.0 plus one extra entry of 0.1 — presumably the
    # "no object" class of Mask2Former's classification loss; confirm.
    cfg.model.decode_head.loss_cls.class_weight = [1.0] * num_class + [0.1]

    _apply_training_settings(cfg, batch_size, max_iters, val_interval)
    return _finalize_and_dump(cfg, model_config_path, dataset_config_path, work_dir, save_dir)


def create_segformer_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a SegFormer training config.

    Uses the crop size already defined by the merged configs, sets plain BN
    and the class count on the decode head, then applies the shared training
    settings.

    Args:
        model_config_path: Path to the base model config file.
        dataset_config_path: Path to the dataset config file.
        num_class: Number of segmentation classes.
        work_dir: Root directory for training outputs.
        save_dir: Directory the generated config file is written to.
        batch_size: Training batch size.
        max_iters: Total training iterations.
        val_interval: Iterations between validation runs.

    Returns:
        Path of the written config file.
    """
    cfg = _load_merged_config(model_config_path, dataset_config_path)

    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    # Relies on crop_size being defined by the model or dataset config.
    cfg.model.data_preprocessor.size = cfg.crop_size
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.num_classes = num_class

    _apply_training_settings(cfg, batch_size, max_iters, val_interval)
    return _finalize_and_dump(cfg, model_config_path, dataset_config_path, work_dir, save_dir)


# CLI model name -> config builder; insertion order defines the --model_name choices.
_CONFIG_BUILDERS = {
    'deeplabv3plus': create_deeplabv3plus_config,
    'knet': create_knet_config,
    'mask2former': create_mask2former_config,
    'segformer': create_segformer_config,
}


def main():
    """Parse command-line arguments and generate the requested model config."""
    parser = argparse.ArgumentParser(description='Train configuration setup for different models.')
    parser.add_argument('--model_name', type=str, required=True, choices=list(_CONFIG_BUILDERS), help='Model name to generate the config for.')
    parser.add_argument('-m', '--model_config', type=str, required=True, help="Path to the model config file")
    parser.add_argument('-d', '--dataset_config', type=str, required=True, help='Path to the dataset config file.')
    parser.add_argument('-c', '--num_class', type=int, required=True, help="Number of classes in the dataset")
    parser.add_argument('-w', '--work_dir', type=str, required=True, help='Directory to save the train result.')
    parser.add_argument('-s', '--save_dir', type=str, required=True, help="Directory to save the generated config file")
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    parser.add_argument('--max_iters', type=int, default=20000, help='Number of training iterations.')
    parser.add_argument('--val_interval', type=int, default=500, help='Interval for validation during training.')
    args = parser.parse_args()

    # argparse `choices` guarantees the key exists.
    builder = _CONFIG_BUILDERS[args.model_name]
    builder(
        model_config_path=args.model_config,
        dataset_config_path=args.dataset_config,
        num_class=args.num_class,
        work_dir=args.work_dir,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        max_iters=args.max_iters,
        val_interval=args.val_interval,
    )


if __name__ == '__main__':
    main()