Spaces:
Sleeping
Sleeping
File size: 1,942 Bytes
62ef5f4 3d85088 62ef5f4 3d85088 62ef5f4 3d85088 62ef5f4 3d85088 62ef5f4 3d85088 62ef5f4 3d85088 62ef5f4 3d85088 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from torch import nn
from timm import create_model
from torchvision.transforms import Normalize
class SwinModel(nn.Module):
    """Frozen timm Swin-V2 backbone that returns multi-scale feature maps.

    The backbone is created with ``features_only=True`` and
    ``out_indices=[-4, -3, -2, -1]``, so ``forward`` returns the outputs of its
    last four stages, each converted to channel-first layout and upsampled 2x.

    Supported ``pretrained_model`` keys and the timm checkpoints they load:

    * ``"swinv2-cr-t-224"`` -> ``swinv2_cr_tiny_ns_224.sw_in1k``
    * ``"swinv2-t-256"``    -> ``swinv2_tiny_window16_256.ms_in1k``
    * ``"swinv2-cr-s-224"`` -> ``swinv2_cr_small_ns_224.sw_in1k``
    """

    # Short config key -> timm model name (including the pretrained-weight tag).
    _TIMM_MODELS = {
        "swinv2-cr-t-224": "swinv2_cr_tiny_ns_224.sw_in1k",
        "swinv2-t-256": "swinv2_tiny_window16_256.ms_in1k",
        "swinv2-cr-s-224": "swinv2_cr_small_ns_224.sw_in1k",
    }
    # Backbones whose feature maps come back channel-last (NHWC) and must be
    # permuted to channel-first (NCHW) before upsampling.
    _CHANNELS_LAST_MODELS = frozenset({"swinv2-t-256"})

    def __init__(self, pretrained_model="swinv2-cr-t-224", device="cuda") -> None:
        """Build the frozen backbone.

        Args:
            pretrained_model: One of the keys of ``_TIMM_MODELS``.
            device: Device the backbone is moved to. Inputs passed to
                ``forward`` are NOT moved automatically.

        Raises:
            NotImplementedError: If ``pretrained_model`` is not supported.
        """
        super().__init__()
        self.device = device
        self.pretrained_model = pretrained_model
        try:
            timm_name = self._TIMM_MODELS[pretrained_model]
        except KeyError:
            supported = ", ".join(sorted(self._TIMM_MODELS))
            raise NotImplementedError(
                f"Unsupported pretrained_model {pretrained_model!r}; "
                f"expected one of: {supported}"
            ) from None
        self.pretrained = create_model(
            timm_name,
            pretrained=True,
            features_only=True,
            out_indices=[-4, -3, -2, -1],
        ).to(device)
        # Standard ImageNet mean/std. It is not applied in forward(); presumably
        # callers normalize inputs with it before calling — confirm at call sites.
        self.normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.upsample = nn.Upsample(scale_factor=2)
        # The backbone is a fixed feature extractor: freeze its weights.
        for param in self.pretrained.parameters():
            param.requires_grad = False
        self.pretrained.eval()

    def train(self, mode: bool = True) -> "SwinModel":
        """Set training mode on this module while keeping the backbone in eval.

        ``nn.Module.train`` recurses into every child module. Without this
        override, a parent model calling ``.train()`` would switch the frozen
        backbone back to training mode. That re-enables any dropout or
        stochastic depth it contains and makes its features nondeterministic.

        Args:
            mode: Training flag applied to this module and its other children.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.pretrained.eval()
        return self

    def forward(self, x):
        """Return the backbone's stage features as channel-first maps upsampled 2x.

        Args:
            x: Image batch accepted by the timm backbone — presumably NCHW and
                already normalized (see ``self.normalizer``); TODO confirm.

        Returns:
            A list with one tensor per backbone stage, each in channel-first
            layout and upsampled by a factor of 2 (``nn.Upsample`` default,
            nearest-neighbour mode).
        """
        features = self.pretrained(x)
        if self.pretrained_model in self._CHANNELS_LAST_MODELS:
            # This timm backbone emits channel-last maps; convert to channel-first.
            features = [feat.permute(0, 3, 1, 2) for feat in features]
        return [self.upsample(feat) for feat in features]