|
|
|
|
|
"""Utility function for weight initialization""" |
|
|
|
import torch.nn as nn |
|
from fvcore.nn.weight_init import c2_msra_fill |
|
|
|
|
|
def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True):
    """
    Performs ResNet style weight initialization, in place.

    Args:
        model (nn.Module): module whose submodules (including itself, as
            yielded by ``model.modules()``) are re-initialized in place.
        fc_init_std (float): the expected standard deviation for fc layer.
        zero_init_final_bn (bool): if True, zero initialize the final bn for
            every bottleneck, i.e. every BatchNorm3d whose
            ``transform_final_bn`` attribute is truthy.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv3d):
            # Follow the initialization method proposed in:
            # {He, Kaiming, et al.
            # "Delving deep into rectifiers: Surpassing human-level
            # performance on imagenet classification."
            # arXiv preprint arXiv:1502.01852 (2015)}
            c2_msra_fill(m)
        elif isinstance(m, nn.BatchNorm3d):
            # The `transform_final_bn` flag is set by model-building code
            # elsewhere; presumably zeroing that BN's scale makes each
            # residual branch start out contributing nothing — confirm
            # against the model definition.
            zero_this_bn = zero_init_final_bn and getattr(
                m, "transform_final_bn", False
            )
            batchnorm_weight = 0.0 if zero_this_bn else 1.0
            # weight/bias are None when the BN was built with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, batchnorm_weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=fc_init_std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
|
|