# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------
import logging
import os
import json
import random
import copy
import itertools
from typing import Any, Dict, List, Set, Union
from datetime import datetime

from mpi4py import MPI
import numpy as np
import torch
from torch.utils.data import DataLoader

from detectron2.projects.deeplab import build_lr_scheduler
from fvcore.common.config import CfgNode
from infinibatch import iterators

from utilities.distributed import is_main_process, get_world_size
from .default_trainer import DefaultTrainer
from .utils.serialization import JSONEncoder, filter_jsonable

logger = logging.getLogger(__name__)

class XDecoder_Trainer(DefaultTrainer):
    """
    Construct XDecoder_Trainer for optimizer and lr_scheduler.
    """
    def create_optimizer_and_scheduler(self):
        """
        Set up self.optimizers and self.lr_schedulers.

        This method initializes self.optimizers and self.lr_schedulers as dictionaries of
        instances of the classes that OPTIMIZER and LR_SCHEDULER in the config file point to,
        with one optimizer and one lr scheduler per model in self.raw_models. Both dictionaries
        have the same keys as self.raw_models.
        """
        self.opt['init_optimizer_in_deepspeed'] = False
        self.opt['init_lr_scheduler_in_deepspeed'] = False
        self.optimizers = {module_name: None for module_name in self.model_names}
        self.lr_schedulers = {module_name: None for module_name in self.model_names}
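
        # Hyperparameters for the parameter groups come from the SOLVER section of the
        # config: a base lr / weight decay plus separate weight-decay values for norm
        # layers, embeddings and (optionally) biases.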
        cfg_solver = self.opt['SOLVER']
        weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
        weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
        weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)

        defaults = {}
        defaults["lr"] = cfg_solver['BASE_LR']
        defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']
        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )
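
        # Optionally freeze parameters: a parameter is frozen when its name contains a key
        # of SOLVER.FIX_PARAM whose value is True, unless its name also matches an entry in
        # SOLVER.IGNORE_FIX. Per-key parameter counts are logged for inspection.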
        fix_param = self.opt['SOLVER'].get('FIX_PARAM', {})
        ignore_fix = self.opt['SOLVER'].get('IGNORE_FIX', [])
        for _module_name in self.model_names:
            flag_continue = False
            module_params = {}
            for name, param in self.raw_models[_module_name].named_parameters():
                for ig in ignore_fix:
                    if ig in name:
                        flag_continue = True
                        break
                if flag_continue:
                    flag_continue = False
                    continue
                for key, value in fix_param.items():
                    if key in name and value == True:
                        param.requires_grad = False
                    if key in name:
                        if key not in module_params:
                            module_params[key] = 0
                        module_params[key] += param.numel()
            logger.info(f"Module {_module_name} has parameters: {module_params}")
            # raise NotImplementedError("Please check the fix_param and ignore_fix in the config file")

        lr_multiplier = self.opt['SOLVER']['LR_MULTIPLIER']
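
        # Build one optimizer per model. Every trainable parameter gets its own parameter
        # group so that the learning rate (via SOLVER.LR_MULTIPLIER) and the weight decay
        # can be overridden based on the module / parameter name.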
        for _module_name in self.model_names:
            # parameters = self.raw_models[module_name].get_training_parameters()
            # self.optimizers[module_name] = optimizer_class(parameters, **optimizer_parameters)
            # params = []
            # for module_param_name, value in self.raw_models[module_name].named_parameters(recurse=True):
            params: List[Dict[str, Any]] = []
            memo: Set[torch.nn.parameter.Parameter] = set()
            for module_name, module in self.raw_models[_module_name].named_modules():
                for module_param_name, value in module.named_parameters(recurse=False):
                    if not value.requires_grad:
                        continue
                    # Avoid duplicating parameters
                    if value in memo:
                        continue
                    memo.add(value)

                    hyperparams = copy.copy(defaults)
                    for key, lr_mul in lr_multiplier.items():
                        if key in "{}.{}".format(module_name, module_param_name):
                            hyperparams["lr"] = hyperparams["lr"] * lr_mul
                            if is_main_process():
                                logger.info("Modify Learning rate of {}: {}".format("{}.{}".format(module_name, module_param_name), lr_mul))

                    if (
                        "relative_position_bias_table" in module_param_name
                        or "absolute_pos_embed" in module_param_name
                    ):
                        hyperparams["weight_decay"] = 0.0
                    if isinstance(module, norm_module_types):
                        hyperparams["weight_decay"] = weight_decay_norm
                    if isinstance(module, torch.nn.Embedding):
                        hyperparams["weight_decay"] = weight_decay_embed
                    if "bias" in module_name:
                        hyperparams["weight_decay"] = weight_decay_bias
                    params.append({"params": [value], **hyperparams})
            def maybe_add_full_model_gradient_clipping(optim):
                # detectron2 doesn't have full model gradient clipping now
                clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
                enable = (
                    cfg_solver['CLIP_GRADIENTS']['ENABLED']
                    and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
                    and clip_norm_val > 0.0
                )

                class FullModelGradientClippingOptimizer(optim):
                    def step(self, closure=None):
                        all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                        torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                        super().step(closure=closure)

                return FullModelGradientClippingOptimizer if enable else optim
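
            # Instantiate the optimizer named by SOLVER.OPTIMIZER (SGD or ADAMW), wrapped
            # with full-model gradient clipping when enabled above.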
            optimizer_type = cfg_solver['OPTIMIZER']
            if optimizer_type == "SGD":
                optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                    params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
                )
            elif optimizer_type == "ADAMW":
                optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                    params, cfg_solver['BASE_LR']
                )
            else:
                raise NotImplementedError(f"no optimizer type {optimizer_type}")
            self.optimizers[_module_name] = optimizer
            self.optimizers[_module_name].zero_grad()
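
        # Convert epoch-based settings to iterations for the scheduler: MAX_ITER is
        # epochs * updates per epoch, and STEPS (given as fractions of total training)
        # are scaled to absolute iteration counts.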
        num_epoch = self.opt['SOLVER']['MAX_NUM_EPOCHS']
        cfg_solver['MAX_ITER'] = num_epoch * self.train_params['updates_per_epoch']
        cfg_solver['STEPS'] = [int(x * cfg_solver['MAX_ITER']) for x in cfg_solver['STEPS']]
        logger.info(f"Calculate MAX_ITER @ {cfg_solver['MAX_ITER']} and STEPS @ {cfg_solver['STEPS']}")
        for module_name in self.model_names:
            scheduler_cfg = CfgNode({'SOLVER': cfg_solver})
            self.lr_schedulers[module_name] = build_lr_scheduler(scheduler_cfg, self.optimizers[module_name])
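
        # Log total and trainable parameter counts per model as a sanity check.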
        for module_name in self.model_names:
            num_params = 0
            num_trainable_params = 0
            for name, param in self.raw_models[module_name].named_parameters():
                num_params += param.numel()
                if param.requires_grad:
                    num_trainable_params += param.numel()
            logger.info(f"Total number of parameters in {module_name} module (on each GPU): {num_params}")
            logger.info(f"Number of trainable parameters in {module_name} module (on each GPU): {num_trainable_params}")