# Page-header residue captured with this source ("Spaces: Sleeping"); not part of the module.
from . import _utils as utils
class EncoderMixin:
    """Add encoder functionality such as:
    - output channels specification of feature tensors (produced by encoder)
    - patching first convolution for arbitrary input channels

    The methods below read ``self._out_channels`` (indexed and sliced) and
    ``self._depth`` (used as an int), so concrete encoders mixing this in
    must define both, and must override ``get_stages`` before
    ``make_dilated`` can be used.
    """

    # Stride reported by ``output_stride`` until ``make_dilated`` records a
    # new value on the instance.
    _output_stride = 32

    def out_channels(self):
        """Return channels dimensions for each tensor of forward output of encoder.

        Only the first ``self._depth + 1`` entries of ``self._out_channels``
        are returned.
        """
        # NOTE(review): if callers read this as an attribute
        # (``encoder.out_channels``), it needs to be a ``@property`` —
        # confirm against call sites before changing the interface.
        return self._out_channels[: self._depth + 1]

    def output_stride(self):
        """Return the encoder's effective output stride.

        This is the recorded stride (32 by default, or the value applied by
        ``make_dilated``) capped at ``2 ** self._depth`` — presumably because
        each of the ``_depth`` stages downsamples by at most 2 (TODO confirm).
        """
        return min(self._output_stride, 2**self._depth)

    def set_in_channels(self, in_channels, pretrained=True):
        """Change first convolution channels.

        Args:
            in_channels: number of input channels the encoder should accept.
                A value of 3 is treated as a no-op.
            pretrained: forwarded unchanged to ``utils.patch_first_conv``.
        """
        if in_channels == 3:
            return

        self._in_channels = in_channels
        # A leading entry of 3 presumably mirrors the raw input channels, so
        # keep it in sync with the new input channel count (TODO confirm).
        if self._out_channels[0] == 3:
            self._out_channels = (in_channels, *self._out_channels[1:])

        utils.patch_first_conv(
            model=self, new_in_channels=in_channels, pretrained=pretrained
        )

    def get_stages(self):
        """Override it in your implementation.

        The result must support indexing: ``make_dilated`` looks up entries
        4 and/or 5 in it.
        """
        raise NotImplementedError

    def make_dilated(self, output_stride):
        """Switch the encoder to an output stride of 16 or 8.

        Supported values:
            16: stage 5 is passed to ``utils.replace_strides_with_dilation``
                with dilation rate 2.
            8: stages 4 and 5 are passed with dilation rates 2 and 4.

        The new stride is recorded only after every selected stage has been
        processed, so if ``get_stages`` raises (e.g. it was not overridden)
        ``output_stride()`` keeps reporting the previous value.

        Raises:
            ValueError: if ``output_stride`` is neither 16 nor 8.
            NotImplementedError: propagated from ``get_stages`` when a subclass
                does not override it.
        """
        if output_stride == 16:
            stage_list, dilation_list = (5,), (2,)
        elif output_stride == 8:
            stage_list, dilation_list = (4, 5), (2, 4)
        else:
            raise ValueError(
                "Output stride should be 16 or 8, got {}.".format(output_stride)
            )

        stages = self.get_stages()
        for stage_idx, dilation_rate in zip(stage_list, dilation_list):
            utils.replace_strides_with_dilation(
                module=stages[stage_idx], dilation_rate=dilation_rate
            )

        # Commit the new stride last so an aborted conversion is not reported
        # as having taken effect.
        self._output_stride = output_stride