naveensp committed (verified)
Commit 11f684e · Parent: 4a1fa64

Upload initialization.py with huggingface_hub

Files changed (1)
  1. initialization.py +95 -0
initialization.py ADDED
@@ -0,0 +1,95 @@
+ import math
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+
+ from .config import InitFnType, ModelConfig
+ from .util import StrEnum
+
+ __all__ = ["init_weights", "ModuleType"]
+
+
+ class ModuleType(StrEnum):
+     in_module = "in"
+     out_module = "out"
+     emb = "emb"
+     final_out = "final_out"
+
+
+ def init_weights(
+     config: ModelConfig,
+     module: Union[nn.Linear, nn.Embedding],
+     d: Optional[int] = None,
+     layer_id: Optional[int] = None,
+     std_factor: float = 1.0,
+     type_of_module: Optional[ModuleType] = None,
+ ) -> None:
+     """
+     Initialize weights of a linear or embedding module.
+
+     :param config: The model config.
+     :param module: The linear or embedding submodule to initialize.
+     :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
+         for fused layers.
+     :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
+         ``1 / sqrt(2 * (layer_id + 1))``.
+     """
+     d = d if d is not None else config.d_model
+     if config.init_fn == InitFnType.normal:
+         std = config.init_std * std_factor
+         if config.init_cutoff_factor is not None:
+             cutoff_value = config.init_cutoff_factor * std
+             nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
+         else:
+             nn.init.normal_(module.weight, mean=0.0, std=std)
+     elif config.init_fn == InitFnType.mitchell:
+         std = std_factor / math.sqrt(d)
+         if layer_id is not None:
+             std = std / math.sqrt(2 * (layer_id + 1))
+         nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
+     elif config.init_fn == InitFnType.kaiming_normal:
+         nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
+     elif config.init_fn == InitFnType.fan_in:
+         std = std_factor / math.sqrt(d)
+         nn.init.normal_(module.weight, mean=0.0, std=std)
+     elif config.init_fn == InitFnType.full_megatron:
+         if type_of_module is None:
+             raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
+
+         cutoff_factor = config.init_cutoff_factor
+         if cutoff_factor is None:
+             cutoff_factor = 3
+
+         if type_of_module == ModuleType.in_module:
+             # for att_proj (same as QKV), ff_proj
+             std = config.init_std
+         elif type_of_module == ModuleType.out_module:
+             # for attn_out, ff_out
+             std = config.init_std / math.sqrt(2.0 * config.n_layers)
+         elif type_of_module == ModuleType.emb:
+             # positional embeddings (wpe)
+             # token embeddings (wte)
+             std = config.init_std
+         elif type_of_module == ModuleType.final_out:
+             # final output (ff_out)
+             std = config.d_model**-0.5
+         else:
+             raise RuntimeError(f"Unknown module type '{type_of_module}'")
+         nn.init.trunc_normal_(
+             module.weight,
+             mean=0.0,
+             std=std,
+             a=-cutoff_factor * std,
+             b=cutoff_factor * std,
+         )
+     else:
+         raise NotImplementedError(config.init_fn)
+
+     if isinstance(module, nn.Linear):
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+
+         if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
+             with torch.no_grad():
+                 module.weight.div_(math.sqrt(2 * config.n_layers))
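
For context, here is a minimal usage sketch showing how `init_weights` might be applied to the submodules of one transformer block. Everything outside the calls themselves — the config values, the module shapes, the vocabulary size, and the decision of which modules receive `layer_id` — is an illustrative assumption, not part of this commit; only `init_weights`, `ModuleType`, and the `ModelConfig` fields the function reads (`init_fn`, `init_std`, `init_cutoff_factor`, `d_model`, `n_layers`) come from the uploaded file.

# Illustrative sketch, not part of this commit: initializing one hypothetical
# transformer block with the "mitchell" scheme. All concrete values are assumptions.
import torch.nn as nn

from .config import InitFnType, ModelConfig
from .initialization import ModuleType, init_weights

config = ModelConfig(init_fn=InitFnType.mitchell, d_model=2048, n_layers=16)  # assumed values
block_id = 3  # zero-based index of this block

# Hypothetical submodules of the block (fused QKV projection, attention output, MLP).
att_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
attn_out = nn.Linear(config.d_model, config.d_model, bias=False)
ff_proj = nn.Linear(config.d_model, 4 * config.d_model, bias=False)
ff_out = nn.Linear(4 * config.d_model, config.d_model, bias=False)
wte = nn.Embedding(50304, config.d_model)  # assumed vocabulary size

# Input-side projections: d is the effective fan-in (d_model even for the fused QKV layer).
# type_of_module is only consulted by the full_megatron method; it is passed here for completeness.
init_weights(config, att_proj, d=config.d_model, type_of_module=ModuleType.in_module)
init_weights(config, ff_proj, d=config.d_model, type_of_module=ModuleType.in_module)

# Output-side projections: layer_id shrinks the "mitchell" std by 1 / sqrt(2 * (block_id + 1)).
init_weights(config, attn_out, d=config.d_model, layer_id=block_id, type_of_module=ModuleType.out_module)
init_weights(config, ff_out, d=4 * config.d_model, layer_id=block_id, type_of_module=ModuleType.out_module)

# Token embedding.
init_weights(config, wte, type_of_module=ModuleType.emb)

Under these assumptions, `ff_out` is drawn with std = 1 / sqrt(4 * 2048) / sqrt(2 * (3 + 1)) ≈ 0.0039, truncated at ±3 std, and any linear biases are zeroed by the final `isinstance(module, nn.Linear)` branch.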