|
import os
|
|
import argparse
|
|
from mmengine import Config
|
|
|
|
def create_deeplabv3plus_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a DeepLabV3+ training configuration.

    Loads the model config, merges the dataset config into it, overrides
    the segmentation heads for ``num_class`` classes, sets training
    hyper-parameters, and dumps the resulting config to ``save_dir``.

    Args:
        model_config_path (str): Path to the base model config file.
        dataset_config_path (str): Path to the dataset config file.
        num_class (int): Number of segmentation classes.
        work_dir (str): Root directory for training outputs.
        save_dir (str): Directory where the generated config file is written.
        batch_size (int): Training dataloader batch size.
        max_iters (int): Total number of training iterations.
        val_interval (int): Iterations between validation runs.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Fixed crop size; the data preprocessor must use the same size.
    cfg.crop_size = (512, 512)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # Plain BN (non-sync) applied to backbone and both heads.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

    cfg.model.decode_head.num_classes = num_class
    cfg.model.auxiliary_head.num_classes = num_class

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    # Fixed seed so runs generated from this config are reproducible.
    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir

    # "<dataset>_<model-dir>", e.g. "ade_deeplabv3plus".
    # NOTE(review): assumes model_config_path has at least two directory
    # components -- confirm for relative/odd paths.
    name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Fix: ensure the destination exists before dumping (previously only
    # work_dir was created, so dump failed for a missing save_dir).
    os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
|
|
|
def create_knet_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a K-Net training configuration.

    Loads the model config, merges the dataset config into it, overrides
    the kernel-generate and auxiliary heads for ``num_class`` classes,
    sets training hyper-parameters, and dumps the result to ``save_dir``.

    Args:
        model_config_path (str): Path to the base model config file.
        dataset_config_path (str): Path to the dataset config file.
        num_class (int): Number of segmentation classes.
        work_dir (str): Root directory for training outputs.
        save_dir (str): Directory where the generated config file is written.
        batch_size (int): Training dataloader batch size.
        max_iters (int): Total number of training iterations.
        val_interval (int): Iterations between validation runs.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Recorded in the dumped config; not applied to model submodules here.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    # NOTE(review): cfg.crop_size is assumed to come from the merged model
    # or dataset config (unlike the DeepLabV3+ builder, which sets it) --
    # confirm it is defined there.
    cfg.model.data_preprocessor.size = cfg.crop_size

    # K-Net nests its classifier head inside kernel_generate_head.
    cfg.model.decode_head.kernel_generate_head.num_classes = num_class
    cfg.model.auxiliary_head.num_classes = num_class

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    # Fixed seed so runs generated from this config are reproducible.
    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir

    # "<dataset>_<model-dir>", e.g. "ade_knet".
    # NOTE(review): assumes model_config_path has at least two directory
    # components -- confirm for relative/odd paths.
    name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]

    # Fix: dropped the earlier dead `cfg.work_dir = work_dir` assignment,
    # which was unconditionally overwritten here.
    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Fix: ensure the destination exists before dumping (previously only
    # work_dir was created, so dump failed for a missing save_dir).
    os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
|
|
|
def create_mask2former_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a Mask2Former training configuration.

    Loads the model config, merges the dataset config into it, overrides
    the decode head for ``num_class`` classes (including the "no object"
    class weight), sets training hyper-parameters, and dumps the result
    to ``save_dir``.

    Args:
        model_config_path (str): Path to the base model config file.
        dataset_config_path (str): Path to the dataset config file.
        num_class (int): Number of segmentation classes.
        work_dir (str): Root directory for training outputs.
        save_dir (str): Directory where the generated config file is written.
        batch_size (int): Training dataloader batch size.
        max_iters (int): Total number of training iterations.
        val_interval (int): Iterations between validation runs.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Fixed crop size; the data preprocessor must use the same size.
    cfg.crop_size = (512, 512)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # Recorded in the dumped config; not applied to model submodules here.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    cfg.model.decode_head.num_classes = num_class
    # Per-class weight 1.0 plus a down-weighted (0.1) trailing entry for
    # the "no object" class used by the mask classification loss.
    cfg.model.decode_head.loss_cls.class_weight = [1.0] * num_class + [0.1]

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    # Fixed seed so runs generated from this config are reproducible.
    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir

    # "<dataset>_<model-dir>", e.g. "ade_mask2former".
    # NOTE(review): assumes model_config_path has at least two directory
    # components -- confirm for relative/odd paths.
    name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Fix: ensure the destination exists before dumping (previously only
    # work_dir was created, so dump failed for a missing save_dir).
    os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
|
|
|
def create_segformer_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
    """Build and save a SegFormer training configuration.

    Loads the model config, merges the dataset config into it, overrides
    the decode head for ``num_class`` classes, sets training
    hyper-parameters, and dumps the result to ``save_dir``.

    Args:
        model_config_path (str): Path to the base model config file.
        dataset_config_path (str): Path to the dataset config file.
        num_class (int): Number of segmentation classes.
        work_dir (str): Root directory for training outputs.
        save_dir (str): Directory where the generated config file is written.
        batch_size (int): Training dataloader batch size.
        max_iters (int): Total number of training iterations.
        val_interval (int): Iterations between validation runs.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Plain BN (non-sync) for the decode head.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    # NOTE(review): cfg.crop_size is assumed to come from the merged model
    # or dataset config (unlike the DeepLabV3+ builder, which sets it) --
    # confirm it is defined there.
    cfg.model.data_preprocessor.size = cfg.crop_size
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg

    cfg.model.decode_head.num_classes = num_class

    cfg.train_dataloader.batch_size = batch_size

    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'

    # Fixed seed so runs generated from this config are reproducible.
    cfg['randomness'] = dict(seed=0)

    cfg.save_dir = save_dir

    # "<dataset>_<model-dir>", e.g. "ade_segformer".
    # NOTE(review): assumes model_config_path has at least two directory
    # components -- confirm for relative/odd paths.
    name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Fix: ensure the destination exists before dumping (previously only
    # work_dir was created, so dump failed for a missing save_dir).
    os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
|
|
|
|
def main():
    """CLI entry point: parse arguments and generate the chosen model's config.

    Fix: the original mixed three bare ``if`` statements with a trailing
    ``elif`` and repeated the identical keyword-argument call four times;
    a dispatch table keyed by ``--model_name`` (whose values are restricted
    by ``choices=``) removes both the inconsistency and the duplication.
    """
    parser = argparse.ArgumentParser(description='Train configuration setup for different models.')
    parser.add_argument('--model_name', type=str, required=True, choices=['deeplabv3plus', 'knet', 'mask2former', 'segformer'],
                        help='Model name to generate the config for.')
    parser.add_argument('-m', '--model_config', type=str, required=True, help="Path to the model config file")
    parser.add_argument('-d', '--dataset_config', type=str, required=True, help='Path to the dataset config file.')
    parser.add_argument('-c', '--num_class', type=int, required=True, help="Number of classes in the dataset")
    parser.add_argument('-w', '--work_dir', type=str, required=True, help='Directory to save the train result.')
    parser.add_argument('-s', '--save_dir', type=str, required=True, help="Directory to save the generated config file")
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    parser.add_argument('--max_iters', type=int, default=20000, help='Number of training iterations.')
    parser.add_argument('--val_interval', type=int, default=500, help='Interval for validation during training.')

    args = parser.parse_args()

    # Dispatch table: argparse's choices= guarantees the key exists.
    builders = {
        'deeplabv3plus': create_deeplabv3plus_config,
        'knet': create_knet_config,
        'mask2former': create_mask2former_config,
        'segformer': create_segformer_config,
    }
    builders[args.model_name](
        model_config_path=args.model_config,
        dataset_config_path=args.dataset_config,
        num_class=args.num_class,
        work_dir=args.work_dir,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        max_iters=args.max_iters,
        val_interval=args.val_interval,
    )
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|
|