|
import torch.nn as nn |
|
import torch.nn.init as init |
|
|
|
def weight_init(m: nn.Module, spread: float = 1.0) -> None:
    '''
    Initializes a model's parameters.

    Credits to: https://gist.github.com/jeasinema

    The module is modified in place and nothing is returned. Module types
    that are not listed below are left untouched, which makes the function
    safe to pass to ``Module.apply`` on a model containing arbitrary layers.

    Scheme per layer type:

    * Conv1d / ConvTranspose1d: weight ~ N(0, spread), bias ~ N(0, spread).
    * Conv2d / Conv3d / ConvTranspose2d / ConvTranspose3d: weight is
      Xavier-normal with ``gain=spread``, bias ~ N(0, spread).
    * BatchNorm1d / BatchNorm2d / BatchNorm3d: weight ~ N(0, spread),
      bias set to 0. Layers built with ``affine=False`` have no weight or
      bias and are skipped.
    * Linear: weight is Xavier-normal with ``gain=spread``,
      bias ~ N(0, spread).
    * LSTM / LSTMCell / GRU / GRUCell: parameters with two or more
      dimensions are initialised orthogonally, all others ~ N(0, spread).

    A missing bias (``bias=False``) is skipped for every layer type.

    Args:
        m: The module to initialise.
        spread: Standard deviation used for the normal initialisations and
            gain used for the Xavier initialisations.

    Usage:
        model = Model()
        model.apply(weight_init)

        # With a non-default spread:
        model.apply(lambda mod: weight_init(mod, spread=0.5))
    '''
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        init.normal_(m.weight.data, mean=0, std=spread)
        if m.bias is not None:
            init.normal_(m.bias.data, mean=0, std=spread)
    elif isinstance(
        m,
        (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d),
    ):
        init.xavier_normal_(m.weight.data, gain=spread)
        if m.bias is not None:
            init.normal_(m.bias.data, mean=0, std=spread)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # With affine=False both weight and bias are None; dereferencing
        # `.data` on them would raise AttributeError, so guard each one.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=0, std=spread)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data, gain=spread)
        # Explicit check instead of a try/except AttributeError, so that an
        # unrelated AttributeError is never silently swallowed.
        if m.bias is not None:
            init.normal_(m.bias.data, mean=0, std=spread)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        for param in m.parameters():
            if param.dim() >= 2:
                # Weight matrices.
                init.orthogonal_(param.data)
            else:
                # One-dimensional parameters (the bias vectors).
                init.normal_(param.data, mean=0, std=spread)