""" This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. Reference: https://github.com/facebookresearch/Mask2Former/blob/main/train_net.py MAFT-Plus Training Script. This script is a simplified version of the training script in detectron2/tools. """ try: # ignore ShapelyDeprecationWarning from fvcore from shapely.errors import ShapelyDeprecationWarning import warnings warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning) except: pass import copy import itertools import logging import os # os.environ['CUDA_VISIBLE_DEVICES'] = '2,4,6' from collections import OrderedDict from typing import Any, Dict, List, Set import torch import detectron2.utils.comm as comm from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg from detectron2.data import MetadataCatalog, build_detection_train_loader from detectron2.engine import ( DefaultTrainer, default_argument_parser, default_setup, launch, ) from detectron2.evaluation import ( CityscapesInstanceEvaluator, CityscapesSemSegEvaluator, COCOEvaluator, COCOPanopticEvaluator, DatasetEvaluators, LVISEvaluator, SemSegEvaluator, verify_results, ) from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler from detectron2.solver.build import maybe_add_gradient_clipping from detectron2.utils.logger import setup_logger from maft import ( COCOInstanceNewBaselineDatasetMapper, COCOPanopticNewBaselineDatasetMapper, COCOSemanticNewBaselineDatasetMapper, InstanceSegEvaluator, #SemSegEvaluator, MaskFormerInstanceDatasetMapper, MaskFormerPanopticDatasetMapper, MaskFormerSemanticDatasetMapper, SemanticSegmentorWithTTA, add_maskformer2_config, add_fcclip_config, add_mask_adapter_config, ) class Trainer(DefaultTrainer): """ Extension of the Trainer class adapted to FCCLIP. """ @classmethod def build_evaluator(cls, cfg, dataset_name, output_folder=None): """ Create evaluator(s) for a given dataset. This uses the special metadata "evaluator_type" associated with each builtin dataset. For your own dataset, you can simply create an evaluator manually in your script and do not have to worry about the hacky if-else logic here. """ if output_folder is None: output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") evaluator_list = [] evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type # semantic segmentation if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic": evaluator_list.append( SemSegEvaluator( dataset_name, distributed=True, output_dir=output_folder, ) ) # panoptic segmentation elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj": evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_dir=output_folder)) if len(evaluator_list) == 0: raise NotImplementedError( "no Evaluator for the dataset {} with the type {}".format( dataset_name, evaluator_type ) ) elif len(evaluator_list) == 1: return evaluator_list[0] return DatasetEvaluators(evaluator_list) @classmethod def build_train_loader(cls, cfg): # Semantic segmentation dataset mapper if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic": mapper = MaskFormerSemanticDatasetMapper(cfg, True) return build_detection_train_loader(cfg, mapper=mapper) # Panoptic segmentation dataset mapper elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic": mapper = MaskFormerPanopticDatasetMapper(cfg, True) return build_detection_train_loader(cfg, mapper=mapper) # Instance segmentation dataset mapper elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_instance": mapper = MaskFormerInstanceDatasetMapper(cfg, True) return build_detection_train_loader(cfg, mapper=mapper) # coco instance segmentation lsj new baseline elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_instance_lsj": mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True) return build_detection_train_loader(cfg, mapper=mapper) # coco panoptic segmentation lsj new baseline elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj": mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True) return build_detection_train_loader(cfg, mapper=mapper) # coco panoptic segmentation lsj new baseline elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_semantic_lsj": mapper = COCOSemanticNewBaselineDatasetMapper(cfg, True) return build_detection_train_loader(cfg, mapper=mapper) else: mapper = None return build_detection_train_loader(cfg, mapper=mapper) @classmethod def build_lr_scheduler(cls, cfg, optimizer): """ It now calls :func:`detectron2.solver.build_lr_scheduler`. Overwrite it if you'd like a different scheduler. """ return build_lr_scheduler(cfg, optimizer) @classmethod def build_optimizer(cls, cfg, model): weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED defaults = {} defaults["lr"] = cfg.SOLVER.BASE_LR defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY norm_module_types = ( torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm, # NaiveSyncBatchNorm inherits from BatchNorm2d torch.nn.GroupNorm, torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, torch.nn.LayerNorm, torch.nn.LocalResponseNorm, ) params: List[Dict[str, Any]] = [] memo: Set[torch.nn.parameter.Parameter] = set() for module_name, module in model.named_modules(): for module_param_name, value in module.named_parameters(recurse=False): if not value.requires_grad: continue # Avoid duplicating parameters if value in memo: continue memo.add(value) hyperparams = copy.copy(defaults) if "backbone" in module_name: hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER if ( "relative_position_bias_table" in module_param_name or "absolute_pos_embed" in module_param_name ): print(module_param_name) hyperparams["weight_decay"] = 0.0 if isinstance(module, norm_module_types): hyperparams["weight_decay"] = weight_decay_norm if isinstance(module, torch.nn.Embedding): hyperparams["weight_decay"] = weight_decay_embed params.append({"params": [value], **hyperparams}) def maybe_add_full_model_gradient_clipping(optim): # detectron2 doesn't have full model gradient clipping now clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE enable = ( cfg.SOLVER.CLIP_GRADIENTS.ENABLED and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" and clip_norm_val > 0.0 ) class FullModelGradientClippingOptimizer(optim): def step(self, closure=None): all_params = itertools.chain(*[x["params"] for x in self.param_groups]) torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) super().step(closure=closure) return FullModelGradientClippingOptimizer if enable else optim optimizer_type = cfg.SOLVER.OPTIMIZER if optimizer_type == "SGD": optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM ) elif optimizer_type == "ADAMW": optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( params, cfg.SOLVER.BASE_LR ) else: raise NotImplementedError(f"no optimizer type {optimizer_type}") if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": optimizer = maybe_add_gradient_clipping(cfg, optimizer) return optimizer @classmethod def test_with_TTA(cls, cfg, model): logger = logging.getLogger("detectron2.trainer") # In the end of training, run an evaluation with TTA. logger.info("Running inference with test-time augmentation ...") model = SemanticSegmentorWithTTA(cfg, model) evaluators = [ cls.build_evaluator( cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") ) for name in cfg.DATASETS.TEST ] res = cls.test(cfg, model, evaluators) res = OrderedDict({k + "_TTA": v for k, v in res.items()}) return res def setup(args): """ Create configs and perform basic setups. """ cfg = get_cfg() # for poly lr schedule add_deeplab_config(cfg) add_maskformer2_config(cfg) add_fcclip_config(cfg) add_mask_adapter_config(cfg) cfg.merge_from_file(args.config_file) cfg.merge_from_list(args.opts) cfg.merge_from_list(['SEED', 123]) cfg.freeze() default_setup(cfg, args) # Setup logger for "maft-plus" module setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maft-plus") return cfg def main(args): # torch.multiprocessing.set_start_method('spawn') cfg = setup(args) if args.eval_only: model = Trainer.build_model(cfg) DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume ) res = Trainer.test(cfg, model) if cfg.TEST.AUG.ENABLED: res.update(Trainer.test_with_TTA(cfg, model)) if comm.is_main_process(): verify_results(cfg, res) return res trainer = Trainer(cfg) trainer.resume_or_load(resume=args.resume) return trainer.train() if __name__ == "__main__": args = default_argument_parser().parse_args() print("Command Line Args:", args) launch( main, args.num_gpus, num_machines=args.num_machines, machine_rank=args.machine_rank, dist_url=args.dist_url, args=(args,), )