File size: 1,108 Bytes
2a13495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import timm
import torch.nn as nn
class TimmUniversalEncoder(nn.Module):
    """Encoder that wraps a model built by ``timm.create_model``.

    The wrapped model is created with ``features_only=True`` and
    ``out_indices=tuple(range(depth))``. ``forward`` returns the input
    tensor followed by whatever the wrapped model returns.
    """

    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
        """Build the wrapped timm model and record the encoder metadata.

        Args:
            name: Model name passed to ``timm.create_model``.
            pretrained: Forwarded to timm as ``pretrained``.
            in_channels: Forwarded to timm as ``in_chans``; also the first
                entry of ``out_channels``.
            depth: Number of feature indices requested
                (``out_indices=tuple(range(depth))``).
            output_stride: Forwarded to timm only when it differs from 32.
        """
        super().__init__()

        create_args = {"in_chans": in_channels, "features_only": True}
        # Not every timm model accepts an ``output_stride`` argument, so it
        # is only passed along when a value other than 32 is requested.
        if output_stride != 32:
            create_args["output_stride"] = output_stride
        create_args["pretrained"] = pretrained
        create_args["out_indices"] = tuple(range(depth))

        self.model = timm.create_model(name, **create_args)

        self._in_channels = in_channels
        stage_channels = self.model.feature_info.channels()
        self._out_channels = [in_channels] + stage_channels
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x):
        """Return ``[x]`` followed by the features the wrapped model produces."""
        extracted = self.model(x)
        return [x] + extracted

    @property
    def out_channels(self):
        """Channel counts: ``in_channels`` first, then the model's feature channels."""
        return self._out_channels

    @property
    def output_stride(self):
        """The requested output stride, capped at ``2 ** depth``."""
        stride_cap = 2 ** self._depth
        return min(self._output_stride, stride_cap)
|