naveensp committed
Commit be33a6c · verified · 1 Parent(s): b24270f

Delete initialization.py

Files changed (1)
  1. initialization.py +0 -95
initialization.py DELETED
@@ -1,95 +0,0 @@
-import math
-from typing import Optional, Union
-
-import torch
-import torch.nn as nn
-
-from .config import InitFnType, ModelConfig
-from .util import StrEnum
-
-__all__ = ["init_weights", "ModuleType"]
-
-
-class ModuleType(StrEnum):
-    in_module = "in"
-    out_module = "out"
-    emb = "emb"
-    final_out = "final_out"
-
-
-def init_weights(
-    config: ModelConfig,
-    module: Union[nn.Linear, nn.Embedding],
-    d: Optional[int] = None,
-    layer_id: Optional[int] = None,
-    std_factor: float = 1.0,
-    type_of_module: Optional[ModuleType] = None,
-) -> None:
-    """
-    Initialize weights of a linear or embedding module.
-
-    :param config: The model config.
-    :param module: The linear or embedding submodule to initialize.
-    :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
-        for fused layers.
-    :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
-        ``1 / sqrt(2 * (layer_id + 1))``.
-    """
-    d = d if d is not None else config.d_model
-    if config.init_fn == InitFnType.normal:
-        std = config.init_std * std_factor
-        if config.init_cutoff_factor is not None:
-            cutoff_value = config.init_cutoff_factor * std
-            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
-        else:
-            nn.init.normal_(module.weight, mean=0.0, std=std)
-    elif config.init_fn == InitFnType.mitchell:
-        std = std_factor / math.sqrt(d)
-        if layer_id is not None:
-            std = std / math.sqrt(2 * (layer_id + 1))
-        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
-    elif config.init_fn == InitFnType.kaiming_normal:
-        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
-    elif config.init_fn == InitFnType.fan_in:
-        std = std_factor / math.sqrt(d)
-        nn.init.normal_(module.weight, mean=0.0, std=std)
-    elif config.init_fn == InitFnType.full_megatron:
-        if type_of_module is None:
-            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
-
-        cutoff_factor = config.init_cutoff_factor
-        if cutoff_factor is None:
-            cutoff_factor = 3
-
-        if type_of_module == ModuleType.in_module:
-            # for att_proj (same as QKV), ff_proj
-            std = config.init_std
-        elif type_of_module == ModuleType.out_module:
-            # for attn_out, ff_out
-            std = config.init_std / math.sqrt(2.0 * config.n_layers)
-        elif type_of_module == ModuleType.emb:
-            # positional embeddings (wpe)
-            # token embeddings (wte)
-            std = config.init_std
-        elif type_of_module == ModuleType.final_out:
-            # final output (ff_out)
-            std = config.d_model**-0.5
-        else:
-            raise RuntimeError(f"Unknown module type '{type_of_module}'")
-        nn.init.trunc_normal_(
-            module.weight,
-            mean=0.0,
-            std=std,
-            a=-cutoff_factor * std,
-            b=cutoff_factor * std,
-        )
-    else:
-        raise NotImplementedError(config.init_fn)
-
-    if isinstance(module, nn.Linear):
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-
-        if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
-            with torch.no_grad():
-                module.weight.div_(math.sqrt(2 * config.n_layers))
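
For readers who want the gist of what the deleted file did without its `.config` and `.util` dependencies, the sketch below reproduces just the "mitchell" branch as a standalone helper. This is an illustrative approximation, not code from this repository: the name `mitchell_init_` and the example dimensions are made up, and only PyTorch is assumed.

```python
import math
from typing import Optional

import torch.nn as nn


def mitchell_init_(linear: nn.Linear, d: int, layer_id: Optional[int] = None, std_factor: float = 1.0) -> None:
    """Truncated-normal init mirroring the 'mitchell' branch of the deleted init_weights (hypothetical helper)."""
    std = std_factor / math.sqrt(d)
    if layer_id is not None:
        # Deeper layers get a smaller std, scaled by 1 / sqrt(2 * (layer_id + 1)) as in the diff above.
        std = std / math.sqrt(2 * (layer_id + 1))
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)


# Example with made-up sizes: a feed-forward projection in layer 3 of a d_model=1024 model.
proj = nn.Linear(1024, 4096)
mitchell_init_(proj, d=1024, layer_id=3)
```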