|
|
|
|
|
|
|
|
|
|
|
|
|
"""Parameter initialization.""" |
|
|
|
import torch |
|
|
|
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm |
|
|
|
|
|
def initialize(model, init_type="pytorch"):
    """Initialize Transformer module.

    Unless ``init_type`` is ``"pytorch"`` (which leaves the model untouched),
    three steps are applied in order:

    1. every parameter with more than one dimension is filled with the
       requested initializer;
    2. every 1-D parameter is zeroed;
    3. every ``torch.nn.Embedding`` and ``LayerNorm`` submodule is reset with
       its own ``reset_parameters()``.

    :param torch.nn.Module model: transformer instance
    :param str init_type: initialization type; one of ``"pytorch"``,
        ``"xavier_uniform"``, ``"xavier_normal"``, ``"kaiming_uniform"`` or
        ``"kaiming_normal"``
    :raises ValueError: if ``init_type`` is not one of the names above. The
        check happens before any parameter is modified.
    """
    if init_type == "pytorch":
        return

    # Kaiming variants compute their gain for a ReLU nonlinearity
    # (presumably matching the feed-forward activation; verify per model).
    initializers = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda w: torch.nn.init.kaiming_uniform_(
            w, nonlinearity="relu"
        ),
        "kaiming_normal": lambda w: torch.nn.init.kaiming_normal_(
            w, nonlinearity="relu"
        ),
    }
    init_fn = initializers.get(init_type)
    # Validate before touching the model. Previously the name was only checked
    # when a multi-dimensional parameter was reached, so a model without one
    # silently had its 1-D parameters zeroed under an unknown init_type.
    if init_fn is None:
        raise ValueError(f"Unknown initialization: {init_type}")

    for p in model.parameters():
        if p.dim() > 1:
            # Weight init: xavier/kaiming need fan-in/fan-out, which torch
            # only defines for tensors with at least two dimensions.
            init_fn(p.data)
        elif p.dim() == 1:
            # Bias init: zero every 1-D parameter. 0-dim parameters are left
            # as they are.
            p.data.zero_()

    # Restore the modules' own default init. The loop above re-drew embedding
    # tables and zeroed LayerNorm's 1-D affine parameters (if any).
    for m in model.modules():
        if isinstance(m, (torch.nn.Embedding, LayerNorm)):
            m.reset_parameters()
|
|