|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
class ResnetDilated(nn.Module):
    """ResNet feature extractor whose late stages use dilation instead of stride.

    The wrapper re-uses the stem (``conv1``, ``bn1``, ``relu``, ``maxpool``)
    and the four stages (``layer1`` .. ``layer4``) of ``orig_resnet``. Before
    doing so it rewrites the convolutions of the last stage(s): stride-2
    convolutions become stride-1, and 3x3 convolutions receive a dilation
    (with matching padding) so the spatial size is no longer reduced there.

    Note:
        ``orig_resnet`` is modified in place and its sub-modules are shared
        with this wrapper; no copy is made.

    Args:
        orig_resnet: Module exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool`` and ``layer1`` .. ``layer4`` attributes.
        dilate_scale: ``8`` rewrites ``layer3`` with dilation 2 and
            ``layer4`` with dilation 4; ``16`` rewrites only ``layer4`` with
            dilation 2. Any other value leaves every layer untouched.
            (Presumably this is the overall output stride for a standard
            ResNet -- confirm against the caller.)
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # Stem of the network, taken over unchanged.
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu = orig_resnet.relu
        self.maxpool = orig_resnet.maxpool

        # Residual stages (layer3 / layer4 may have been rewritten above).
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        """Replace the stride of a 2-D convolution by dilation, in place.

        Meant to be passed to ``nn.Module.apply``; every module that is not
        an ``nn.Conv2d`` is ignored.

        Args:
            m: Module visited by ``apply``.
            dilate: Dilation assigned to 3x3 convolutions that had no
                stride. A 3x3 convolution that had stride 2 gets
                ``dilate // 2`` instead.
        """
        # Match on the type, not on the class name: a name test such as
        # "'Conv' in classname" also hits containers like a "ConvBlock"
        # (which have no ``stride`` attribute and would raise) and other
        # convolution kinds the tuple comparisons below were not written for.
        if not isinstance(m, nn.Conv2d):
            return

        if m.stride == (2, 2):
            # The down-sampling convolution itself: drop the stride.
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                half = dilate // 2
                m.dilation = (half, half)
                m.padding = (half, half)
        elif m.kernel_size == (3, 3):
            # Remaining 3x3 convolutions get the full dilation; padding is
            # set to the same value so a 3x3 kernel keeps the spatial size.
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        """Run ``x`` through the stem and the four stages.

        Args:
            x: Input accepted by ``conv1`` of the wrapped network.

        Returns:
            The output of ``layer4``.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
|
|