# Model/run/run_configs.py
# Uploaded by CVRPDataset ("Upload 7 files", commit 1060621, verified).
import os
import argparse
from mmengine import Config
def create_deeplabv3plus_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir,
                                batch_size, max_iters, val_interval, crop_size=(512, 512)):
    """Build a DeepLabV3+ training config for a custom dataset and dump it to disk.

    Steps:
        1. Load the model config.
        2. Merge the dataset config into it.
        3. Override the class count, normalization, batch size, schedule,
           hooks and random seed.
        4. Write the result to ``<save_dir>/<name>.py``, where ``name`` is
           ``<dataset tag>_<model tag>``.

    Args:
        model_config_path: Path to the DeepLabV3+ model config. The name of the
            directory that contains it becomes the model tag
            (e.g. ``configs/deeplabv3plus/x.py`` -> ``deeplabv3plus``).
        dataset_config_path: Path to the dataset config. The dataset tag is the
            part of the file name, with its extension stripped, before the
            first ``_``.
        num_class: Number of classes for the decode and auxiliary heads.
        work_dir: Root output directory. ``cfg.work_dir`` is set to
            ``<work_dir>/<name>`` and that directory is created.
        save_dir: Directory that receives the generated config file. It is
            created if it does not exist.
        batch_size: Training dataloader batch size.
        max_iters: Total number of training iterations.
        val_interval: Run validation every ``val_interval`` iterations.
        crop_size: ``(H, W)`` stored in ``cfg.crop_size`` and used as the data
            preprocessor's ``size``. Defaults to ``(512, 512)``.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Set crop size. NOTE(review): only the preprocessor size is updated here.
    # Any crop size already used inside the merged dataset pipelines is left
    # as-is; confirm the two agree.
    cfg.crop_size = tuple(crop_size)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # Use plain BN with trainable affine parameters in the backbone and both heads.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.num_classes = num_class
    cfg.model.auxiliary_head.num_classes = num_class
    cfg.train_dataloader.batch_size = batch_size

    # Training schedule, logging/checkpoint hooks and a fixed seed.
    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'
    cfg['randomness'] = dict(seed=0)

    # Experiment name: "<dataset tag>_<model tag>".
    cfg.save_dir = save_dir
    dataset_tag = os.path.splitext(os.path.basename(dataset_config_path))[0].split('_')[0]
    # The model tag is the name of the config's parent directory. Taking the
    # basename of the absolute dirname, rather than split(os.sep)[1], also works
    # for './'-prefixed, absolute and single-component paths.
    model_tag = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_tag + "_" + model_tag

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Ensure the destination directory exists before writing the config file.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
def create_knet_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir,
                       batch_size, max_iters, val_interval, crop_size=None):
    """Build a K-Net training config for a custom dataset and dump it to disk.

    Steps:
        1. Load the model config.
        2. Merge the dataset config into it.
        3. Override the class count, batch size, schedule, hooks and random seed.
        4. Write the result to ``<save_dir>/<name>.py``, where ``name`` is
           ``<dataset tag>_<model tag>``.

    Args:
        model_config_path: Path to the K-Net model config. The name of the
            directory that contains it becomes the model tag.
        dataset_config_path: Path to the dataset config. The dataset tag is the
            part of the file name, with its extension stripped, before the
            first ``_``.
        num_class: Number of classes for the kernel-generate head and the
            auxiliary head.
        work_dir: Root output directory. ``cfg.work_dir`` is set to
            ``<work_dir>/<name>`` and that directory is created.
        save_dir: Directory that receives the generated config file. It is
            created if it does not exist.
        batch_size: Training dataloader batch size.
        max_iters: Total number of training iterations.
        val_interval: Run validation every ``val_interval`` iterations.
        crop_size: Optional ``(H, W)`` override for ``cfg.crop_size``. When
            ``None`` (default), the crop size from the loaded configs is used.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # NOTE(review): unlike the DeepLabV3+ builder, this value is never copied
    # into the backbone/head norm_cfg entries. It presumably only sets the
    # top-level `norm_cfg` key of the dumped config; confirm whether K-Net
    # should use BN here.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    # The crop size comes from the loaded configs unless explicitly overridden.
    if crop_size is not None:
        cfg.crop_size = tuple(crop_size)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # The K-Net decode head keeps num_classes on its nested kernel_generate_head.
    cfg.model.decode_head.kernel_generate_head.num_classes = num_class
    cfg.model.auxiliary_head.num_classes = num_class
    cfg.train_dataloader.batch_size = batch_size

    # Training schedule, logging/checkpoint hooks and a fixed seed.
    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'
    cfg['randomness'] = dict(seed=0)

    # Experiment name: "<dataset tag>_<model tag>".
    cfg.save_dir = save_dir
    dataset_tag = os.path.splitext(os.path.basename(dataset_config_path))[0].split('_')[0]
    # The model tag is the name of the config's parent directory. This is robust
    # to './'-prefixed, absolute and single-component paths.
    model_tag = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_tag + "_" + model_tag

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Ensure the destination directory exists before writing the config file.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
def create_mask2former_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir,
                              batch_size, max_iters, val_interval, crop_size=(512, 512)):
    """Build a Mask2Former training config for a custom dataset and dump it to disk.

    Steps:
        1. Load the model config.
        2. Merge the dataset config into it.
        3. Override the class count, classification-loss class weights, batch
           size, schedule, hooks and random seed.
        4. Write the result to ``<save_dir>/<name>.py``, where ``name`` is
           ``<dataset tag>_<model tag>``.

    Args:
        model_config_path: Path to the Mask2Former model config. The name of the
            directory that contains it becomes the model tag.
        dataset_config_path: Path to the dataset config. The dataset tag is the
            part of the file name, with its extension stripped, before the
            first ``_``.
        num_class: Number of classes for the decode head.
        work_dir: Root output directory. ``cfg.work_dir`` is set to
            ``<work_dir>/<name>`` and that directory is created.
        save_dir: Directory that receives the generated config file. It is
            created if it does not exist.
        batch_size: Training dataloader batch size.
        max_iters: Total number of training iterations.
        val_interval: Run validation every ``val_interval`` iterations.
        crop_size: ``(H, W)`` stored in ``cfg.crop_size`` and used as the data
            preprocessor's ``size``. Defaults to ``(512, 512)``.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Set crop size. Only the preprocessor size is updated here.
    cfg.crop_size = tuple(crop_size)
    cfg.model.data_preprocessor.size = cfg.crop_size

    # NOTE(review): this value is not propagated into any model sub-config
    # below. It presumably only sets the top-level `norm_cfg` key of the
    # dumped config.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.model.decode_head.num_classes = num_class
    # One weight per class plus a trailing, down-weighted extra entry
    # (presumably Mask2Former's "no object" class; confirm against the head's loss).
    cfg.model.decode_head.loss_cls.class_weight = [1.0] * num_class + [0.1]
    cfg.train_dataloader.batch_size = batch_size

    # Training schedule, logging/checkpoint hooks and a fixed seed.
    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'
    cfg['randomness'] = dict(seed=0)

    # Experiment name: "<dataset tag>_<model tag>".
    cfg.save_dir = save_dir
    dataset_tag = os.path.splitext(os.path.basename(dataset_config_path))[0].split('_')[0]
    # The model tag is the name of the config's parent directory. This is robust
    # to './'-prefixed, absolute and single-component paths.
    model_tag = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_tag + "_" + model_tag

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Ensure the destination directory exists before writing the config file.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
def create_segformer_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir,
                            batch_size, max_iters, val_interval, crop_size=None):
    """Build a SegFormer training config for a custom dataset and dump it to disk.

    Steps:
        1. Load the model config.
        2. Merge the dataset config into it.
        3. Override the decode-head normalization and class count, batch size,
           schedule, hooks and random seed.
        4. Write the result to ``<save_dir>/<name>.py``, where ``name`` is
           ``<dataset tag>_<model tag>``.

    Args:
        model_config_path: Path to the SegFormer model config. The name of the
            directory that contains it becomes the model tag.
        dataset_config_path: Path to the dataset config. The dataset tag is the
            part of the file name, with its extension stripped, before the
            first ``_``.
        num_class: Number of classes for the decode head.
        work_dir: Root output directory. ``cfg.work_dir`` is set to
            ``<work_dir>/<name>`` and that directory is created.
        save_dir: Directory that receives the generated config file. It is
            created if it does not exist.
        batch_size: Training dataloader batch size.
        max_iters: Total number of training iterations.
        val_interval: Run validation every ``val_interval`` iterations.
        crop_size: Optional ``(H, W)`` override for ``cfg.crop_size``. When
            ``None`` (default), the crop size from the loaded configs is used.
    """
    cfg = Config.fromfile(model_config_path)
    dataset_cfg = Config.fromfile(dataset_config_path)
    cfg.merge_from_dict(dataset_cfg)

    # Plain BN with trainable affine parameters, applied to the decode head.
    cfg.norm_cfg = dict(type='BN', requires_grad=True)

    # The crop size comes from the loaded configs unless explicitly overridden.
    if crop_size is not None:
        cfg.crop_size = tuple(crop_size)
    cfg.model.data_preprocessor.size = cfg.crop_size

    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.num_classes = num_class
    cfg.train_dataloader.batch_size = batch_size

    # Training schedule, logging/checkpoint hooks and a fixed seed.
    cfg.train_cfg.max_iters = max_iters
    cfg.train_cfg.val_interval = val_interval
    cfg.default_hooks.logger.interval = 100
    cfg.default_hooks.checkpoint.interval = 2500
    cfg.default_hooks.checkpoint.max_keep_ckpts = 1
    cfg.default_hooks.checkpoint.save_best = 'mIoU'
    cfg['randomness'] = dict(seed=0)

    # Experiment name: "<dataset tag>_<model tag>".
    cfg.save_dir = save_dir
    dataset_tag = os.path.splitext(os.path.basename(dataset_config_path))[0].split('_')[0]
    # The model tag is the name of the config's parent directory. This is robust
    # to './'-prefixed, absolute and single-component paths.
    model_tag = os.path.basename(os.path.dirname(os.path.abspath(model_config_path)))
    name = dataset_tag + "_" + model_tag

    cfg.work_dir = os.path.join(work_dir, name)
    os.makedirs(cfg.work_dir, exist_ok=True)

    # Ensure the destination directory exists before writing the config file.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_config_file = os.path.join(save_dir, f"{name}.py")
    cfg.dump(save_config_file)
    print(f"Configuration saved to: {save_config_file}")
def main(argv=None):
    """Parse command-line options and write the training config for the chosen model.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``,
            so the CLI can also be driven from Python.
    """
    # One builder per supported model. All accept the same keyword arguments,
    # so dispatch is a dictionary lookup instead of a chain of `if`s.
    builders = {
        'deeplabv3plus': create_deeplabv3plus_config,
        'knet': create_knet_config,
        'mask2former': create_mask2former_config,
        'segformer': create_segformer_config,
    }

    parser = argparse.ArgumentParser(description='Train configuration setup for different models.')
    parser.add_argument('--model_name', type=str, required=True, choices=list(builders),
                        help='Model name to generate the config for.')
    parser.add_argument('-m', '--model_config', type=str, required=True, help="Path to the model config file")
    parser.add_argument('-d', '--dataset_config', type=str, required=True, help='Path to the dataset config file.')
    parser.add_argument('-c', '--num_class', type=int, required=True, help="Number of classes in the dataset")
    parser.add_argument('-w', '--work_dir', type=str, required=True, help='Directory to save the train result.')
    parser.add_argument('-s', '--save_dir', type=str, required=True, help="Directory to save the generated config file")
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
    parser.add_argument('--max_iters', type=int, default=20000, help='Number of training iterations.')
    parser.add_argument('--val_interval', type=int, default=500, help='Interval for validation during training.')
    args = parser.parse_args(argv)

    # argparse's `choices` guarantees the key exists, so exactly one builder runs.
    builders[args.model_name](
        model_config_path=args.model_config,
        dataset_config_path=args.dataset_config,
        num_class=args.num_class,
        work_dir=args.work_dir,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        max_iters=args.max_iters,
        val_interval=args.val_interval,
    )
# Entry point: run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()