Spaces:
Sleeping
Sleeping
import timm | |
import torch.nn as nn | |
class TimmUniversalEncoder(nn.Module):
    """Multi-level feature encoder backed by a ``timm`` model.

    The wrapped model is created with ``features_only=True`` and
    ``out_indices=tuple(range(depth))``, so calling it yields a list of
    feature maps. ``forward`` prepends the untouched input to that list.

    Args:
        name: Model name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
        in_channels: Number of input channels (``in_chans`` in timm).
        depth: Number of feature levels requested from the timm model.
        output_stride: Requested output stride. It is forwarded to timm only
            when it differs from 32, because not all models accept it.
    """

    def __init__(
        self,
        name: str,
        pretrained: bool = True,
        in_channels: int = 3,
        depth: int = 5,
        output_stride: int = 32,
    ):
        super().__init__()

        create_kwargs = {
            "in_chans": in_channels,
            "features_only": True,
            "pretrained": pretrained,
            "out_indices": tuple(range(depth)),
        }
        # Not all models support the output stride argument, so it is left
        # out for the value 32 and only forwarded when a different stride
        # is requested.
        if output_stride != 32:
            create_kwargs["output_stride"] = output_stride

        self.model = timm.create_model(name, **create_kwargs)

        self._in_channels = in_channels
        # Index 0 describes the raw input that forward() prepends; the rest
        # come from the timm feature metadata.
        self._out_channels = [in_channels] + self.model.feature_info.channels()
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x):
        """Return ``[x]`` followed by the feature list produced by the model."""
        extracted = self.model(x)
        return [x] + extracted

    # NOTE(review): exposed as plain methods (no @property) in this file;
    # callers must use call syntax -- confirm against call sites.
    def out_channels(self) -> list:
        """Return the channel counts: ``in_channels`` first, then per feature."""
        return self._out_channels

    def output_stride(self) -> int:
        """Return the requested output stride, capped at ``2 ** depth``."""
        return min(self._output_stride, 2 ** self._depth)