# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Backbones from the TIMM library.""" from typing import List, Tuple import torch from timm.models import create_model from torch import nn class TimmBackbone(nn.Module): def __init__( self, name: str, features: Tuple[str, ...], ): super().__init__() out_indices = tuple(int(f[len("layer") :]) for f in features) backbone = create_model( name, pretrained=True, in_chans=3, features_only=True, out_indices=out_indices, ) num_channels = backbone.feature_info.channels() self.channel_list = num_channels[::-1] self.body = backbone def forward(self, x: torch.Tensor) -> List[torch.Tensor]: xs = self.body(x) out = [] for i, x in enumerate(xs): out.append(x) return out