File size: 882 Bytes
231edce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch.nn as nn


def change_num_input_channels(model, in_channels=1):
    """
    Adapt the first 3-input-channel convolution of ``model`` to ``in_channels``.

    Walks ``model.modules()`` in order and rewrites the first ``nn.Conv2d`` or
    ``nn.Conv3d`` whose ``in_channels`` is 3. Its kernel is summed over the
    input-channel axis, divided by ``in_channels`` and tiled ``in_channels``
    times. The layer's response is therefore preserved when every new input
    channel carries the same signal, e.g. a grayscale image. It then matches
    what the original layer produced for that signal replicated into 3
    channels. The bias is left untouched, and grouped convolutions are not
    treated specially. If no such layer exists, the model is returned
    unchanged.

    Args:
        model: ``nn.Module`` to modify; it is mutated in place.
        in_channels: new number of input channels for that layer (>= 1).

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        ValueError: if ``in_channels`` is smaller than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be >= 1, got {in_channels}")
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            m.in_channels = in_channels
            # Detach: this is a one-off weight surgery, not something to
            # record in the autograd graph.
            old_weight = m.weight
            summed = old_weight.detach().sum(1, keepdim=True)
            # Divide by the NEW channel count so that summing over the
            # tiled channels gives back the original channel-summed kernel.
            reps = [1] * summed.ndim
            reps[1] = in_channels
            new_weight = (summed / in_channels).repeat(reps)
            # Keep the frozen/trainable state of the weight being replaced.
            m.weight = nn.Parameter(
                new_weight, requires_grad=old_weight.requires_grad
            )
            break
    return model


def change_initial_stride(model, stride, in_channels):
    """
    Set the stride of the first convolution that takes ``in_channels`` inputs.

    Walks ``model.modules()`` in order and updates the first ``nn.Conv2d`` or
    ``nn.Conv3d`` whose ``in_channels`` equals ``in_channels``. An integer
    ``stride`` is expanded to one entry per spatial dimension. This keeps the
    tuple form that convolution modules store after construction, so code
    reading ``m.stride[0]`` keeps working. If no layer matches, the model is
    returned unchanged.

    Args:
        model: ``nn.Module`` to modify; it is mutated in place.
        stride: int, or a sequence with one value per spatial dimension.
        in_channels: input-channel count identifying the layer to change.

    Returns:
        The same ``model`` object, for chaining.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == in_channels:
            if isinstance(stride, int):
                # kernel_size is stored per spatial dim, so its length is
                # the number of spatial dimensions (2 or 3).
                m.stride = (stride,) * len(m.kernel_size)
            else:
                m.stride = tuple(stride)
            break
    return model