Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
def change_num_input_channels(model, in_channels=1):
    """Adapt the first 3-channel convolution of ``model`` to ``in_channels`` inputs.

    The first ``nn.Conv2d``/``nn.Conv3d`` found in ``model.modules()`` order
    whose ``in_channels`` equals 3 is modified in place. Its kernel is summed
    over the input-channel axis, divided by ``in_channels`` and tiled
    ``in_channels`` times, so each filter's total weight over the input
    channels stays the same as before.

    Assumes number of input channels in model is 3.

    Args:
        model: Module whose first 3-channel convolution should be adapted.
        in_channels: Desired number of input channels (must be >= 1).

    Returns:
        The same ``model`` object (modified in place). It is returned
        untouched when ``in_channels`` is 3 or no matching convolution exists.

    Raises:
        ValueError: If ``in_channels`` is smaller than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be >= 1, got {in_channels}")
    if in_channels == 3:
        # Already the assumed layout; averaging the channels here would only
        # throw away the existing per-channel filters.
        return model

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            # NOTE(review): assumes groups == 1, i.e. weight.shape[1] == 3 --
            # grouped convolutions are not handled here; confirm with callers.
            old_weight = m.weight
            # First, sum across channels, then divide by number of channels.
            W = old_weight.detach().sum(1, keepdim=True) / in_channels
            # Then, repeat by number of channels.
            size = [1] * W.ndim
            size[1] = in_channels
            # Keep the original trainability: a frozen stem must stay frozen.
            m.weight = nn.Parameter(
                W.repeat(size), requires_grad=old_weight.requires_grad
            )
            # Update the bookkeeping attribute only once the new weight exists.
            m.in_channels = in_channels
            break
    return model
def change_initial_stride(model, stride, in_channels):
    """Set the stride of the first convolution that takes ``in_channels`` inputs.

    The first ``nn.Conv2d``/``nn.Conv3d`` found in ``model.modules()`` order
    whose ``in_channels`` equals ``in_channels`` gets its ``stride`` attribute
    replaced in place. Only that one layer is touched.

    Args:
        model: Module to search.
        stride: New stride. An int is expanded to one entry per spatial
            dimension; any other iterable is stored as a tuple.
        in_channels: Input-channel count identifying the target convolution.

    Returns:
        The same ``model`` object (unchanged if no convolution matches).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == in_channels:
            if isinstance(stride, int):
                # Conv layers keep stride as a tuple with one value per
                # spatial dimension; expand a scalar to match that form.
                spatial_dims = 3 if isinstance(m, nn.Conv3d) else 2
                m.stride = (stride,) * spatial_dims
            else:
                m.stride = tuple(stride)
            break
    return model