cc0dd3c
1
2
3
4
5
6
7
8
import torch
from torch import nn
from torch.nn import functional as F


def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it.

    Each tensor yielded by ``module.parameters()`` is overwritten with zeros
    through ``nn.init.zeros_``. That function performs the write without
    recording it for autograd, so the parameters stay leaf tensors. Their
    ``requires_grad`` flags are left unchanged.

    Args:
        module: The ``nn.Module`` whose parameters should be zeroed.

    Returns:
        The same ``module`` object, so the call can wrap a constructor
        inline, e.g. ``zero_module(nn.Conv2d(...))``.
    """
    weights = module.parameters()
    for weight in weights:
        nn.init.zeros_(weight)
    return module