import torch


def init_module_weights(module):
    """Initialize the weights of ``module`` according to its type.

    Meant to be applied to every submodule of a model, e.g. via
    ``model.apply(init_module_weights)``.
    """

    # Local import, presumably to avoid a circular dependency between this
    # file and src.model.modules.
    from src.model.modules import QuantizationModule

    # The Gumbel-softmax quantizer requires a special init: normal weights for
    # the logit projection, zero bias, and uniformly initialized codebooks.
    if isinstance(module, QuantizationModule):
        module.weight_proj.weight.data.normal_(mean=0.0, std=1.0)
        module.weight_proj.bias.data.zero_()
        torch.nn.init.uniform_(module.codebooks)
    elif isinstance(module, torch.nn.Linear):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=0.5)
    elif isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
        # Normalization layers start as the identity: zero shift, unit scale.
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    elif isinstance(module, torch.nn.Conv1d):
        # He/Kaiming initialization for convolutional weights.
        torch.nn.init.kaiming_normal_(module.weight.data)

    # Regardless of how the weights were set above, zero the bias of Linear
    # and Conv1d layers when one is present.
    if (
        isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
        and module.bias is not None
    ):
        module.bias.data.zero_()
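

if __name__ == "__main__":
    # Minimal usage sketch: torch.nn.Module.apply walks every submodule and
    # calls init_module_weights on each one. The tiny model below is
    # hypothetical and exists only to exercise the Conv1d/GroupNorm/Linear
    # branches; it is initialized but never run forward. Run from the repo
    # root so that the src.model.modules import inside the function resolves.
    demo = torch.nn.Sequential(
        torch.nn.Conv1d(16, 32, kernel_size=3),
        torch.nn.GroupNorm(num_groups=4, num_channels=32),
        torch.nn.Linear(32, 10),
    )
    demo.apply(init_module_weights)
    # The GroupNorm scale was filled with ones, so this prints tensor(32.).
    print(demo[1].weight.sum())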