ianpan's picture
Initial commit
231edce
raw
history blame contribute delete
882 Bytes
import torch.nn as nn
def change_num_input_channels(model, in_channels=1):
    """
    Adapt the first 3-input-channel Conv2d/Conv3d in ``model`` to accept
    ``in_channels`` input channels, reusing its existing weights.

    The kernel is summed over its 3 input channels, divided by
    ``in_channels`` and tiled ``in_channels`` times along the input-channel
    axis. Because convolution is linear, an input whose ``in_channels``
    channels all hold the same values gives the same response as the
    original conv fed 3 channels of those values. The bias is left unchanged.

    Assumes the layer to adapt has ``in_channels == 3``; only the first such
    conv in ``model.modules()`` order is modified. If none exists, the model
    is returned unchanged.

    Args:
        model: ``nn.Module`` to modify in place.
        in_channels: number of input channels the adapted conv should accept.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``in_channels`` is less than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be >= 1, got {in_channels}")
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            m.in_channels = in_channels
            # Build the new kernel outside autograd; it replaces the parameter.
            weight = m.weight.detach()
            # Collapse the 3 input-channel kernels into one, then split it
            # evenly across the new channels so the summed response is kept.
            weight = weight.sum(1, keepdim=True) / in_channels
            repeats = [1] * weight.ndim
            repeats[1] = in_channels
            weight = weight.repeat(repeats)
            # nn.Parameter defaults to requires_grad=True; keep frozen layers frozen.
            m.weight = nn.Parameter(weight, requires_grad=m.weight.requires_grad)
            break
    return model
def change_initial_stride(model, stride, in_channels):
    """
    Set the stride of the first Conv2d/Conv3d in ``model`` whose
    ``in_channels`` equals ``in_channels``.

    An int ``stride`` is expanded to a tuple with one entry per spatial
    dimension of that conv, matching how ``nn.Conv2d``/``nn.Conv3d`` store
    ``stride``. A sequence is stored as a tuple.

    Args:
        model: ``nn.Module`` to modify in place.
        stride: new stride, either an int or a sequence of ints.
        in_channels: input-channel count used to select the conv to modify.

    Returns:
        The same ``model`` object. If no conv matches, it is unchanged.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == in_channels:
            if isinstance(stride, int):
                # kernel_size is stored as a per-dimension tuple by nn.ConvNd.
                stride = (stride,) * len(m.kernel_size)
            m.stride = tuple(stride)
            break
    return model