File size: 174 Bytes
cc0dd3c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
import torch
from torch import nn
from torch.nn import functional as F

def zero_module(module):
    """Fill every parameter of ``module`` with zeros, in place.

    Iterates over ``module.parameters()`` and overwrites each tensor's
    values with zero. The parameters themselves are not replaced, only
    their contents change.

    Args:
        module: Any object exposing a ``parameters()`` iterable of tensors
            (e.g. an ``nn.Module``).

    Returns:
        The same ``module`` object that was passed in, so the call can be
        used inline when constructing layers.
    """
    # Disable gradient tracking so the in-place writes are not recorded by
    # autograd (in-place edits of leaf tensors that require grad would
    # otherwise raise).
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module