import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union

import torch
from torch import nn

from .fc import FC_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY


# transformer_engine is an optional dependency; te stays None when it is unavailable.
try:
    import transformer_engine.pytorch as te
except ImportError:
    te = None


def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
    del kwargs
    if hasattr(module, "reset_parameters") and isinstance(
        module.reset_parameters, Callable
    ):
        module.reset_parameters()


def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
    # Initialize each fused slice of the weight separately along the fused dimension.
    _fused = getattr(module, "_fused", None)
    if _fused is None:
        raise RuntimeError("Internal logic error: module has no _fused attribute.")
    assert isinstance(module.weight, torch.Tensor)
    (dim, splits) = _fused
    splits = (0, *splits, module.weight.size(dim))
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        init_fn_(module.weight[tuple(slice_indices)])
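
# Illustrative sketch (the tagging convention is assumed here, not defined in this
# file): fused_init_helper_ expects `_fused = (dim, split_points)`, where `dim` is the
# fused dimension and `split_points` are the interior boundaries between sub-weights.
# For a hypothetical fused QKV projection of width 3 * d_model, the layer could be
# tagged so that Q, K, and V are each initialized independently:
#
#     qkv = nn.Linear(d_model, 3 * d_model)
#     qkv._fused = (0, (d_model, 2 * d_model))  # rows [0, d), [d, 2d), [2d, 3d)
#     fused_init_helper_(qkv, partial(torch.nn.init.normal_, mean=0.0, std=0.02))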


def generic_param_init_fn_(
    module: nn.Module,
    init_fn_: Callable,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    # Translate init_div_is_residual into the numeric divisor applied to residual weights.
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, (float, int)):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
        div_is_residual = float(init_div_is_residual)
    else:
        raise ValueError(
            f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
        )
    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
        if hasattr(module, "_fused"):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            assert isinstance(module.bias, torch.Tensor)
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn("Embedding layer initialized to 0.")
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) != 2:
                    raise ValueError(
                        f"Uniform init requires a min and a max limit. User input: {lim}."
                    )
                if lim[0] == lim[1]:
                    warnings.warn(f"Embedding layer initialized to {lim[0]}.")
            else:
                if lim == 0:
                    warnings.warn("Embedding layer initialized to 0.")
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor):
            torch.nn.init.ones_(module.weight)
        if hasattr(module, "bias") and isinstance(module.bias, torch.Tensor):
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert (
                module.q_proj_weight is None
                and module.k_proj_weight is None
                and module.v_proj_weight is None
            )
            assert d_model is not None
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert (
                module.q_proj_weight is not None
                and module.k_proj_weight is not None
                and module.v_proj_weight is not None
            )
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(
            module.out_proj, "_is_residual", False
        ):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    elif te is not None and isinstance(module, te.LayerNormMLP):
        if isinstance(module.layer_norm_weight, torch.Tensor):
            torch.nn.init.ones_(module.layer_norm_weight)
        if isinstance(module.layer_norm_bias, torch.Tensor):
            torch.nn.init.zeros_(module.layer_norm_bias)
        init_fn_(module.fc1_weight)
        if module.fc1_bias is not None:
            assert isinstance(module.fc1_bias, torch.Tensor)
            torch.nn.init.zeros_(module.fc1_bias)
        init_fn_(module.fc2_weight)
        if module.fc2_bias is not None:
            assert isinstance(module.fc2_bias, torch.Tensor)
            torch.nn.init.zeros_(module.fc2_bias)
        with torch.no_grad():
            module.fc2_weight.div_(div_is_residual)
    else:
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(
                f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
            )
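
# Illustrative sketch of the residual-scaling convention consumed above (the
# `_is_residual` flag is assumed to be set by the model code, and the example assumes
# nn.Linear is registered in FC_CLASS_REGISTRY): a projection tagged with
# `_is_residual = True` has its weight divided by sqrt(2 * n_layers) when
# init_div_is_residual is left at its default of True.
#
#     out_proj = nn.Linear(d_model, d_model)
#     out_proj._is_residual = True
#     generic_param_init_fn_(
#         out_proj,
#         init_fn_=partial(torch.nn.init.normal_, mean=0.0, std=0.02),
#         n_layers=n_layers,
#     )
#     # out_proj.weight now has std ~ 0.02 / sqrt(2 * n_layers)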


def _normal_init_(std: float, mean: float = 0.0) -> Callable:
    return partial(torch.nn.init.normal_, mean=mean, std=std)


def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    init_fn_ = _normal_init_(std=std)
    generic_param_init_fn_(
        module=module,
        init_fn_=init_fn_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def baseline_param_init_fn_(
    module: nn.Module,
    init_std: Optional[float],
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    if init_std is None:
        raise ValueError(
            "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
        )
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    # "Small init": std = sqrt(2 / (5 * d_model)), as in Nguyen & Salazar (2019).
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    """From section 2.3.1 of GPT-NeoX-20B:

    An Open-Source Autoregressive Language Model, Black et al. (2022).
    See https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )
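
# Worked check of the residual divisor above (a sketch, not part of the module's API):
# dividing the "small init" std sqrt(2 / (5 * d_model)) by n_layers / sqrt(10) gives
# 2 / (n_layers * sqrt(d_model)), the residual-branch std used in the GPT-NeoX code
# referenced in the docstring. For example, with d_model = 4096 and n_layers = 32:
#
#     std_small = math.sqrt(2 / (5 * 4096))    # ~0.00988
#     residual_div = 32 / math.sqrt(10)        # ~10.12
#     std_residual = std_small / residual_div  # ~0.000977 == 2 / (32 * sqrt(4096))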


def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    **kwargs: Any,
) -> None:
    del kwargs
    kaiming_uniform_ = partial(
        nn.init.kaiming_uniform_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    **kwargs: Any,
) -> None:
    del kwargs
    kaiming_normal_ = partial(
        torch.nn.init.kaiming_normal_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
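
# Illustrative usage sketch (assumed caller pattern; `cfg` and `model` are hypothetical
# names, not defined in this file): look up the configured scheme in MODEL_INIT_REGISTRY,
# bind its settings, and apply it to every submodule.
#
#     init_name = cfg.init_config.pop("name")  # e.g. "kaiming_normal_"
#     model.apply(
#         partial(
#             MODEL_INIT_REGISTRY[init_name],
#             n_layers=cfg.n_layers,
#             d_model=cfg.d_model,
#             **cfg.init_config,
#         )
#     )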