# Spaces:
# Sleeping
# Sleeping
from typing import Optional, Type

import torch.nn as nn
def build_normalization(norm_type: str, dim: Optional[int] = None) -> Type[nn.Module]:
    """
    Overview:
        Look up the normalization layer class that matches ``norm_type`` (and ``dim`` where relevant).
        The class itself is returned, not an instance; the caller instantiates it with its own arguments.
        For beginners, refer to [this article](https://zhuanlan.zhihu.com/p/34879333) to learn more about
        batch normalization.
    Arguments:
        - norm_type (:obj:`str`): Type of the normalization. Currently supports ['BN', 'LN', 'IN', 'SyncBN'].
        - dim (:obj:`Optional[int]`): Dimension of the normalization. Required for 'BN' and 'IN', where only \
            1 and 2 are available; ignored for 'LN' and 'SyncBN'.
    Returns:
        - norm_func (:obj:`Type[nn.Module]`): The corresponding normalization layer class.
    Raises:
        - NotImplementedError: ``dim`` is given together with a ``norm_type`` outside \
            ['BN', 'IN', 'LN', 'SyncBN'].
        - KeyError: The resolved key has no registered class, e.g. an unknown ``norm_type`` with no ``dim``, \
            'BN'/'IN' without ``dim``, or 'BN'/'IN' with a ``dim`` other than 1 or 2.
    """
    dim_specific = ('BN', 'IN')
    dim_agnostic = ('LN', 'SyncBN')

    if dim is None or norm_type in dim_agnostic:
        # Without a dim, or for layers that have no per-dim variant, the type name is the key.
        key = norm_type
    elif norm_type in dim_specific:
        # Per-dim variants are registered under e.g. 'BN1', 'IN2'.
        key = norm_type + str(dim)
    else:
        raise NotImplementedError("not support indicated dim when creates {}".format(norm_type))

    norm_func = {
        'BN1': nn.BatchNorm1d,
        'BN2': nn.BatchNorm2d,
        'LN': nn.LayerNorm,
        'IN1': nn.InstanceNorm1d,
        'IN2': nn.InstanceNorm2d,
        'SyncBN': nn.SyncBatchNorm,
    }
    try:
        return norm_func[key]
    except KeyError:
        # Re-raise with the descriptive message; suppress the bare lookup error in the traceback.
        raise KeyError("invalid norm type: {}".format(key)) from None