import itertools
import torch
from uniperceiver.config import CfgNode
from uniperceiver.utils.registry import Registry

from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

SOLVER_REGISTRY = Registry("SOLVER")
SOLVER_REGISTRY.__doc__ = """
Registry for SOLVER.
"""

_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]

def _create_gradient_clipper(cfg: CfgNode) -> Tuple[Optional[_GradientClipper], Optional[_GradientClipper]]:
    """
    Build a (per_param_clipper, global_clipper) pair from the config; exactly
    one of the two is not None, depending on cfg.SOLVER.GRAD_CLIP_TYPE.
    """
    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.SOLVER.GRAD_CLIP, cfg.SOLVER.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.SOLVER.GRAD_CLIP)

    if cfg.SOLVER.GRAD_CLIP_TYPE == 'value':
        # value clipping can be applied to each parameter independently
        return clip_grad_value, None
    else:
        # norm clipping must see all parameters at once to compute the total norm
        return None, clip_grad_norm


def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Build per-parameter optimizer settings: norm layers get `weight_decay_norm`,
    biases get `base_lr * bias_lr_factor` and `weight_decay_bias`, parameters
    named by `model.no_weight_decay()` get zero decay, and `overrides` maps a
    parameter name to final values that take precedence over everything else.
    """
    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
            }


            if isinstance(module, norm_module_types):
                schedule_params["weight_decay"] = weight_decay_norm
            elif module_param_name == "bias":
                # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
                # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
                # hyperparameters are by default exactly the same as for regular
                # weights.
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params["weight_decay"] = weight_decay_bias

            if no_decay or (module_param_name in no_decay_list):
                schedule_params["weight_decay"] = 0.


            if overrides is not None and module_param_name in overrides:
                schedule_params.update(overrides[module_param_name])
            params += [
                {
                    "params": [value],
                    "lr": schedule_params["lr"],
                    "weight_decay": schedule_params["weight_decay"],
                }
            ]

    return params

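# Usage sketch (hypothetical hyperparameter values for illustration):
#   params = get_default_optimizer_params(model, base_lr=1e-4, weight_decay=0.05,
#                                         weight_decay_norm=0.0, weight_decay_bias=0.0)
#   optimizer = torch.optim.AdamW(params, lr=1e-4)
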
def get_layer_id(module_name, num_layers):
    """
    Assign a parameter to its layer id.
    Modified from BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if module_name.split('.')[0] in [
            'video_embed', 'token_embed', 'prompt_embed', 'visual_embed', 'cls_token'
    ]:
        return 0
    elif module_name.startswith('encoder'):
        return int(module_name.split('.')[2]) + 1
    elif module_name.startswith('predictor'):
        return num_layers
    else:
        raise NotImplementedError(f'unknown layer for module: {module_name}')

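# Example mapping (illustrative module names, num_layers=24):
#   get_layer_id('token_embed', 24)            -> 0   (embeddings)
#   get_layer_id('encoder.layers.5.attn', 24)  -> 6   (encoder.layers.<i> maps to i + 1)
#   get_layer_id('predictor.head', 24)         -> 24  (prediction head)
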
def create_seperate_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_facetor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
    cfg: dict = None,
):
    try:
        from deepspeed.moe.utils import is_moe_param
    except ImportError:
        # fallback when DeepSpeed is unavailable: DeepSpeed's MoE layers mark
        # expert parameters with `allreduce = False`
        def is_moe_param(param: torch.Tensor) -> bool:
            return hasattr(param, "allreduce") and not param.allreduce

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    num_layers = cfg.MODEL.BERT.NUM_HIDDEN_LAYERS + 1
    layer_decay = cfg.SOLVER.LAYER_LR_DECAY
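    # Layer-wise LR decay (BEiT-style): layer i gets scale layer_decay ** (num_layers - i),
    # so embeddings (layer 0) are scaled down the most and the top layer keeps scale 1.0.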
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                break

        for module_param_name, value in module.named_parameters(recurse=False):
            this_scale = layer_scales[get_layer_id(module_name, num_layers)] if layer_decay < 1.0 else 1.0
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)
            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
                "moe": is_moe_param(value),
            }

            if no_decay or (module_param_name in no_decay_list):
                schedule_params["weight_decay"] = 0.
            elif is_wg_param and isinstance(
                    module,
                    torch.nn.Linear) and module_param_name != "bias":
                # only add linear weights in gate function
                schedule_params["lr"] = base_lr * wg_lr_facetor
                schedule_params["weight_decay"] = weight_decay_wg

            elif isinstance(module, torch.nn.Embedding):
                schedule_params['weight_decay'] = weight_decay_embedding

            elif isinstance(module, norm_module_types):
                if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
                    # LayerNorm bias uses the same hyperparameters as a linear bias
                    schedule_params["lr"] = base_lr * bias_lr_factor
                    schedule_params['weight_decay'] = weight_decay_bias
                else:
                    schedule_params['weight_decay'] = weight_decay_norm

            elif module_param_name == "bias" or value.ndim == 1:
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params['weight_decay'] = weight_decay_bias

            params += [{
                "params": [value],
                "lr": max(schedule_params["lr"] * this_scale, cfg.LR_SCHEDULER.get('MIN_LR', 1e-6)),
                "moe": schedule_params['moe'],
                "weight_decay": schedule_params["weight_decay"],
                "name": f'{module_name}.{module_param_name}'
            }]

    return params

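# Sketch of what this grouping produces (hypothetical hyperparameter values;
# cfg must provide MODEL.BERT.NUM_HIDDEN_LAYERS, SOLVER.LAYER_LR_DECAY, etc.):
#   params = create_seperate_moe_param_groups(model, base_lr=2e-4, weight_decay=0.05, cfg=cfg)
#   # -> one dict per trainable parameter:
#   #    {'params': [p], 'lr': ..., 'weight_decay': ..., 'moe': ..., 'name': 'module.param'}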

def create_group_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_facetor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
    cfg: dict = None,
):
    from deepspeed.moe.utils import is_moe_param

    memo: Set[torch.nn.parameter.Parameter] = set()

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    group_params_dict = {}

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                break

        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            # default setting
            lr_of_this_param = base_lr
            wd_of_this_param = weight_decay
            moe_of_this_param = False
            if is_moe_param(value):
                moe_of_this_param = True

            if no_decay or (module_param_name in no_decay_list):
                wd_of_this_param = 0.
            elif is_wg_param and isinstance(
                    module, torch.nn.Linear) and module_param_name != "bias":
                # only add linear weights in gate function
                lr_of_this_param = base_lr * wg_lr_facetor
                wd_of_this_param = weight_decay_wg

            elif isinstance(module, torch.nn.Embedding):
                wd_of_this_param = weight_decay_embedding

            elif isinstance(module, norm_module_types):
                if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
                    # ln bias uses the same params as linear bias
                    lr_of_this_param = base_lr * bias_lr_factor
                    wd_of_this_param = weight_decay_bias
                else:
                    wd_of_this_param = weight_decay_norm

            elif module_param_name == "bias":
                lr_of_this_param = base_lr * bias_lr_factor
                wd_of_this_param = weight_decay_bias

            param_group_name = f'lr_{lr_of_this_param}_wd_{wd_of_this_param}_moe_{moe_of_this_param}'
            if param_group_name not in group_params_dict:
                group_params_dict[param_group_name] = {
                    'params': [],
                    "lr": lr_of_this_param,
                    "weight_decay": wd_of_this_param,
                    'moe': moe_of_this_param,
                    'name': param_group_name,
                    'params_name': [],
                }
            group_params_dict[param_group_name]['params'].append(value)
            group_params_dict[param_group_name]['params_name'].append(
                f'{module_name}.{module_param_name}')


    valid_params_groups = list(group_params_dict.values())
    return valid_params_groups
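
# Note: create_group_moe_param_groups buckets parameters by their
# (lr, weight_decay, moe) combination, so the optimizer sees one group per
# distinct hyperparameter triple, whereas create_seperate_moe_param_groups
# above returns one group per parameter.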




def create_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_facetor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
):
    """
    Bucket parameters into a fixed set of named groups (weight-decay, no-decay,
    norm, bias, gate, and their MoE counterparts); empty groups are dropped.
    """
    from deepspeed.moe.utils import is_moe_param

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    if weight_decay_embedding == 0.0:
        # classify embeddings with the norm layers so they share weight_decay_norm
        norm_module_types = norm_module_types + (torch.nn.Embedding, )
    # otherwise embeddings keep the regular weight decay, like other weights

    params_with_weight_decay = {
        'params': [],
        'name': 'weight_decay_params',
        'params_name': [],
    }
    params_without_weight_decay = {
        'params': [],
        "weight_decay": 0.0,
        'name': 'without_weight_decay_params',
        'params_name': [],
    }
    bias_params = {
        'params': [],
        "lr": base_lr * bias_lr_factor,
        "weight_decay": weight_decay_bias,
        'name': 'bias_params',
        'params_name': [],
    }
    wg_params = {
        'params': [],
        "lr": base_lr * wg_lr_facetor,
        "weight_decay": weight_decay_wg,
        'name': 'wg_params',
        'params_name': [],
    }
    norm_params = {
        'params': [],
        "weight_decay": weight_decay_norm,
        'name': 'norm_params',
        'params_name': [],
    }
    moe_params_with_weight_decay = {
        'params': [],
        'moe': True,
        'name': 'weight_decay_moe_params',
        'params_name': [],
    }
    moe_params_without_weight_decay = {
        'params': [],
        "weight_decay": 0.0,
        'moe': True,
        'name': 'without_weight_decay_moe_params',
        'params_name': [],
    }
    moe_bias_params = {
        'params': [],
        "lr": base_lr * bias_lr_factor,
        "weight_decay": weight_decay_bias,
        'moe': True,
        'name': 'bias_moe_params',
        'params_name': [],
    }
    moe_norm_params = {
        'params': [],
        "weight_decay": weight_decay_norm,
        'moe': True,
        'name': 'norm_moe_params',
        'params_name': [],
    }

    params_groups = [
        params_with_weight_decay, params_without_weight_decay, norm_params, bias_params, wg_params,
        moe_params_with_weight_decay, moe_params_without_weight_decay, moe_norm_params, moe_bias_params
    ]

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                break

        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)
            if is_moe_param(value):
                if no_decay or (module_param_name in no_decay_list):
                    moe_params_without_weight_decay['params'].append(value)
                    moe_params_without_weight_decay['params_name'].append(f'{module_name}.{module_param_name}')
                elif isinstance(module, norm_module_types):
                    moe_norm_params['params'].append(value)
                    moe_norm_params['params_name'].append(f'{module_name}.{module_param_name}')
                elif module_param_name == "bias":
                    moe_bias_params['params'].append(value)
                    moe_bias_params['params_name'].append(f'{module_name}.{module_param_name}')
                else:
                    moe_params_with_weight_decay['params'].append(value)
                    moe_params_with_weight_decay['params_name'].append(f'{module_name}.{module_param_name}')
            else:
                if no_decay or (module_param_name in no_decay_list):
                    params_without_weight_decay['params'].append(value)
                    params_without_weight_decay['params_name'].append(f'{module_name}.{module_param_name}')
                elif is_wg_param and isinstance(module, torch.nn.Linear) and module_param_name != "bias":
                    # only add linear weights in gate function
                    wg_params['params'].append(value)
                    wg_params['params_name'].append(
                        f'{module_name}.{module_param_name}')
                elif isinstance(module, norm_module_types):
                    norm_params['params'].append(value)
                    norm_params['params_name'].append(
                        f'{module_name}.{module_param_name}')
                elif module_param_name == "bias":
                    bias_params['params'].append(value)
                    bias_params['params_name'].append(
                        f'{module_name}.{module_param_name}')
                else:
                    params_with_weight_decay['params'].append(value)
                    params_with_weight_decay['params_name'].append(
                        f'{module_name}.{module_param_name}')

    valid_params_groups = [
        group for group in params_groups if len(group['params']) > 0
    ]

    return valid_params_groups






def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        else:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            global_clipper(all_params)

        super(type(self), self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip

def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type to become a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: type. A subclass of torch.optim.Optimizer

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
            a subclass of it with gradient clipping included in the `step` method.
    """
    if cfg.SOLVER.GRAD_CLIP <= 0:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        optimizer_type = optimizer

    per_param_clipper, global_clipper = _create_gradient_clipper(cfg)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        optimizer_type, per_param_clipper=per_param_clipper, global_clipper=global_clipper
    )
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
        return optimizer
    else:
        return OptimizerWithGradientClip

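# Usage sketch for the clipping wrapper (hypothetical cfg values; note that
# build_optimizer below does not currently apply it):
#   cfg.SOLVER.GRAD_CLIP, cfg.SOLVER.GRAD_CLIP_TYPE, cfg.SOLVER.NORM_TYPE = 1.0, 'norm', 2.0
#   AdamWWithClip = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)
#   optimizer = AdamWWithClip(model.parameters(), lr=1e-4)
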
def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config: parameter groups come from
    create_seperate_moe_param_groups, and the optimizer is LAMB when
    cfg.SOLVER.NAME == 'LAMB', otherwise AdamW.
    """
    params = create_seperate_moe_param_groups(
                    model,
                    base_lr=cfg.SOLVER.BASE_LR,
                    weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                    weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
                    bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
                    wg_lr_facetor=cfg.SOLVER.WG_LR_FACTOR,
                    weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
                    weight_decay_embedding=cfg.SOLVER.WEIGHT_DECAY_EMBEDDING,
                    weight_decay_wg=cfg.SOLVER.WEIGHT_DECAY_WG,
                    cfg=cfg,
                )
    if cfg.SOLVER.NAME == 'LAMB':
        from uniperceiver.optim import LAMB
        optimizer = LAMB(
            params,
            lr=cfg.SOLVER.BASE_LR,
            betas=cfg.SOLVER.BETAS,
            eps=cfg.SOLVER.EPS,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    else:
        optimizer = torch.optim.AdamW(
            params,
            lr=cfg.SOLVER.BASE_LR,
            betas=cfg.SOLVER.BETAS,
            eps=cfg.SOLVER.EPS,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    return optimizer
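
# End-to-end sketch (assumes a populated CfgNode; `build_model` is a
# hypothetical factory, not defined in this file):
#   model = build_model(cfg)
#   optimizer = build_optimizer(cfg, model)  # AdamW unless cfg.SOLVER.NAME == 'LAMB'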