Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
import numpy as np | |
from modelscope.msdatasets import MsDataset | |
class Interpolate(nn.Module):
    """Layer form of ``torch.nn.functional.interpolate``.

    The resize settings are captured at construction time so the operation
    can be placed inside an ``nn.Sequential``.

    Args:
        size: Target spatial size forwarded to ``F.interpolate``.
        scale_factor: Spatial multiplier forwarded to ``F.interpolate``.
        mode: Interpolation algorithm name (default ``"bilinear"``).
        align_corners: Forwarded unchanged to ``F.interpolate``.
    """

    def __init__(self, size=None, scale_factor=None, mode="bilinear", align_corners=False):
        super().__init__()
        self.size, self.scale_factor = size, scale_factor
        self.mode, self.align_corners = mode, align_corners

    def forward(self, x):
        """Return ``x`` resampled with the settings stored on this layer."""
        resize_options = {
            "size": self.size,
            "scale_factor": self.scale_factor,
            "mode": self.mode,
            "align_corners": self.align_corners,
        }
        return F.interpolate(x, **resize_options)
class EvalNet:
    """Inference wrapper pairing a torchvision CNN backbone with a custom head.

    The backbone is looked up by name in ``torchvision.models``; its pooling
    and classification layers are replaced by ``nn.Identity`` and a separate
    ``self.classifier`` head (see ``_create_classifier``) is applied to the
    backbone features.

    Args:
        backbone: Attribute name of a constructor in ``torchvision.models``;
            it must also appear as a ``ver`` entry in the backbone registry.
        cls_num: Number of output channels of the head.
        ori_T: Width the head output is interpolated to.
        imgnet_ver: Registry split passed to ``MsDataset.load``.
        weight_path: Checkpoint file handed to ``torch.load``.
            NOTE(review): the default ``""`` is passed to ``torch.load``
            unchanged, so a real path is effectively required.

    Raises:
        ValueError: If ``backbone`` is unknown to torchvision or the registry,
            or if the registry reports a backbone type this class cannot wrap.
    """

    def __init__(
        self,
        backbone: str,
        cls_num: int,
        ori_T: int,
        imgnet_ver: str = "v1",
        weight_path: str = "",
    ):
        if not hasattr(models, backbone):
            raise ValueError(f"Unsupported model {backbone}.")

        self.imgnet_ver = imgnet_ver
        self.training = bool(weight_path == "")
        self.type, self.weight_url, self.input_size = self._model_info(backbone)
        # Plain attribute lookup; no need to eval() a constructed source string.
        self.model: torch.nn.Module = getattr(models, backbone)()
        self.ori_T = ori_T
        self.out_channel_before_classifier = 0
        self._set_channel_outsize()  # set out channel size
        self.cls_num = cls_num
        self._set_classifier()
        self._pseudo_foward()
        self._load_weights(weight_path)
        self._prepare_for_inference()

    def _load_weights(self, weight_path):
        """Load backbone (and head) weights from ``weight_path``, non-strictly."""
        # SECURITY: torch.load unpickles the file; only open trusted checkpoints.
        map_location = None if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(weight_path, map_location=map_location)
        if self.type == "squeezenet":
            # NOTE(review): assumes a squeezenet checkpoint is a bare backbone
            # state dict without a separate "classifier" entry — confirm.
            self.model.load_state_dict(checkpoint, False)
        else:
            self.model.load_state_dict(checkpoint["model"], False)
            self.classifier.load_state_dict(checkpoint["classifier"], False)

    def _prepare_for_inference(self):
        """Move both modules to CUDA when available and switch them to eval mode."""
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()

        self.model.eval()
        # The head contains BatchNorm2d layers; without eval() they would
        # normalise with per-batch statistics and keep updating running stats.
        self.classifier.eval()

    def _get_backbone(self, backbone_ver, backbone_list):
        """Return the registry entry whose ``"ver"`` equals ``backbone_ver``.

        Raises:
            ValueError: If no entry matches.
        """
        for backbone_info in backbone_list:
            if backbone_ver == backbone_info["ver"]:
                return backbone_info

        raise ValueError("[Backbone not found] Please check if --model is correct!")

    def _model_info(self, backbone: str):
        """Look up ``backbone`` in the registry and return (type, url, input_size)."""
        backbone_list = MsDataset.load(
            "monetjoe/cv_backbones",
            split=self.imgnet_ver,
            cache_dir="./__pycache__",
            trust_remote_code=True,
        )
        backbone_info = self._get_backbone(backbone, backbone_list)
        return (
            str(backbone_info["type"]),
            str(backbone_info["url"]),
            int(backbone_info["input_size"]),
        )

    def _create_classifier(self):
        """Build the head applied to backbone features.

        The height axis is pooled to 1, four transposed convolutions each
        double the width, the result is interpolated to ``(1, ori_T)`` and two
        1x1 convolutions map it to ``cls_num`` channels.
        """
        original_T_size = self.ori_T

        def upsample(in_ch, out_ch):
            # kernel 4 / stride 2 / padding 1 along the width doubles it.
            return [
                nn.ConvTranspose2d(
                    in_ch, out_ch, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_ch),
            ]

        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, None)),  # F -> 1
            *upsample(self.out_channel_before_classifier, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            *upsample(64, 32),  # input for Interp: [bsz, C, 1, T]
            Interpolate(
                size=(1, original_T_size), mode="bilinear", align_corners=False
            ),
            # classifier
            nn.Conv2d(32, 32, kernel_size=(1, 1)),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
        )

    def _set_channel_outsize(self):
        """Record the channel count of the features entering the original head.

        Walks ``named_modules`` and remembers every ``Conv2d``'s
        ``out_channels``. When a ``Conv2d`` inside the original head is reached
        (name contains "classifier"/"head" or equals "fc"), its
        ``in_channels`` is used instead and the scan stops.

        NOTE(review): the scan stops only on a head *convolution*. For heads
        made of linear layers this is equivalent to stopping at the head
        container, because no convolution follows it.

        Raises:
            ValueError: If the backbone contains no ``Conv2d`` at all.
        """
        conv_channels = []
        for name, module in self.model.named_modules():
            is_conv = isinstance(module, torch.nn.Conv2d)
            if is_conv:
                conv_channels.append(module.out_channels)

            in_head = "classifier" in name or name == "fc" or "head" in name
            if in_head and is_conv:
                conv_channels.append(module.in_channels)
                break

        if not conv_channels:
            raise ValueError("[Backbone not supported] No Conv2d layer was found.")

        self.out_channel_before_classifier = conv_channels[-1]

    def _set_classifier(self):
        """Strip the backbone's own pooling/head and attach the custom head.

        Raises:
            ValueError: If ``self.type`` is not one of the supported families.
        """
        if self.type == "resnet":
            self.model.avgpool = nn.Identity()
            self.model.fc = nn.Identity()
        elif self.type in ("vgg", "efficientnet", "convnext"):
            self.model.avgpool = nn.Identity()
            self.model.classifier = nn.Identity()
        elif self.type == "squeezenet":
            self.model.classifier = nn.Identity()
        else:
            # Previously this fell through silently and failed later with an
            # AttributeError on self.classifier.
            raise ValueError(f"Unsupported backbone type {self.type!r} for EvalNet.")

        self.classifier = self._create_classifier()

    def get_input_size(self):
        """Return the square input resolution reported by the registry."""
        return self.input_size

    def _pseudo_foward(self):
        """Run a random batch through the backbone to derive the feature side ``H``.

        ``H`` is later used by ``forward`` to reshape flattened backbone output
        to ``(batch, channels, H, H)``.
        """
        temp = torch.randn(4, 3, self.input_size, self.input_size)
        with torch.no_grad():  # only the output size is needed
            out = self.model(temp)

        self.H = int(np.sqrt(out.size(1) / self.out_channel_before_classifier))

    def forward(self, x):
        """Run backbone and head on ``x`` and return the squeezed head output.

        NOTE(review): ``squeeze()`` removes every size-1 dimension, so a batch
        of one also loses its batch axis; kept for caller compatibility.
        """
        # Follow the head's device instead of assuming CUDA whenever it exists.
        x = x.to(next(self.classifier.parameters()).device)
        out = self.model(x)
        if self.type != "convnext":
            out = out.view(
                out.size(0), self.out_channel_before_classifier, self.H, self.H
            )

        return self.classifier(out).squeeze()
class t_EvalNet:
    """Inference wrapper pairing a torchvision transformer backbone with a custom head.

    Supports backbones whose registry type is ``"vit"`` or
    ``"swin_transformer"``. Token/feature maps from the backbone are fed to
    the head built by ``_create_classifier``.

    Args:
        backbone: Attribute name of a constructor in ``torchvision.models``;
            it must also appear as a ``ver`` entry in the backbone registry.
        cls_num: Number of output channels of the head.
        ori_T: Width the head output is interpolated to.
        imgnet_ver: Registry split passed to ``MsDataset.load``.
        weight_path: Checkpoint file handed to ``torch.load``.
            NOTE(review): the default ``""`` is passed to ``torch.load``
            unchanged, so a real path is effectively required.

    Raises:
        ValueError: If ``backbone`` is unknown to torchvision or the registry,
            or if its registry type is neither ViT nor Swin.
    """

    def __init__(
        self,
        backbone: str,
        cls_num: int,
        ori_T: int,
        imgnet_ver: str = "v1",
        weight_path: str = "",
    ):
        if not hasattr(models, backbone):
            raise ValueError(f"Unsupported model {backbone}.")

        self.imgnet_ver = imgnet_ver
        self.type, self.weight_url, self.input_size = self._model_info(backbone)
        # Plain attribute lookup; no need to eval() a constructed source string.
        self.model: torch.nn.Module = getattr(models, backbone)()
        self.ori_T = ori_T
        if self.type == "vit":
            self.hidden_dim = self.model.hidden_dim
            self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        elif self.type == "swin_transformer":
            self.hidden_dim = self._swin_feature_dim()
        else:
            # Previously this fell through and failed later with an
            # AttributeError on self.classifier.
            raise ValueError(f"Unsupported backbone type {self.type!r} for t_EvalNet.")

        self.cls_num = cls_num
        self._set_classifier()
        self._load_weights(weight_path)
        self._prepare_for_inference()

    def _swin_feature_dim(self):
        """Return the channel width of the Swin feature map.

        Read from the ``in_features`` of the backbone's ``head`` layer so that
        variants wider than 768 work; falls back to the former constant 768
        when the attribute is missing.
        """
        head = getattr(self.model, "head", None)
        return getattr(head, "in_features", 768)

    def _load_weights(self, weight_path):
        """Load backbone and head weights from ``weight_path``, non-strictly."""
        # SECURITY: torch.load unpickles the file; only open trusted checkpoints.
        map_location = None if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(weight_path, map_location=map_location)
        self.model.load_state_dict(checkpoint["model"], False)
        self.classifier.load_state_dict(checkpoint["classifier"], False)

    def _prepare_for_inference(self):
        """Move both modules to CUDA when available and switch them to eval mode."""
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()

        self.model.eval()
        # The head contains BatchNorm2d layers; without eval() they would
        # normalise with per-batch statistics and keep updating running stats.
        self.classifier.eval()

    def _get_backbone(self, backbone_ver, backbone_list):
        """Return the registry entry whose ``"ver"`` equals ``backbone_ver``.

        Raises:
            ValueError: If no entry matches.
        """
        for backbone_info in backbone_list:
            if backbone_ver == backbone_info["ver"]:
                return backbone_info

        raise ValueError("[Backbone not found] Please check if --model is correct!")

    def _model_info(self, backbone: str):
        """Look up ``backbone`` in the registry and return (type, url, input_size)."""
        backbone_list = MsDataset.load(
            "monetjoe/cv_backbones",
            split=self.imgnet_ver,
            cache_dir="./__pycache__",
            trust_remote_code=True,
        )
        backbone_info = self._get_backbone(backbone, backbone_list)
        return (
            str(backbone_info["type"]),
            str(backbone_info["url"]),
            int(backbone_info["input_size"]),
        )

    def _create_classifier(self):
        """Build the head and, as a side effect, set ``self.avgpool``.

        ``self.avgpool`` pools the height axis to 1 and is applied separately
        in the Swin branch of ``forward``. The returned head applies four
        width-doubling transposed convolutions, interpolates to
        ``(1, ori_T)`` and maps to ``cls_num`` channels with 1x1 convolutions.
        """
        original_T_size = self.ori_T
        self.avgpool = nn.AdaptiveAvgPool2d((1, None))  # F -> 1

        def upsample(in_ch, out_ch):
            # kernel 4 / stride 2 / padding 1 along the width doubles it.
            return [
                nn.ConvTranspose2d(
                    in_ch, out_ch, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_ch),
            ]

        return nn.Sequential(
            *upsample(self.hidden_dim, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            *upsample(64, 32),  # input for Interp: [bsz, C, 1, T]
            Interpolate(
                size=(1, original_T_size), mode="bilinear", align_corners=False
            ),
            # classifier
            nn.Conv2d(32, 32, kernel_size=(1, 1)),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
        )

    def _set_classifier(self):
        """Attach the custom head for the supported transformer types."""
        if self.type == "vit" or self.type == "swin_transformer":
            self.classifier = self._create_classifier()

    def get_input_size(self):
        """Return the square input resolution reported by the registry."""
        return self.input_size

    def forward(self, x: torch.Tensor):
        """Run backbone and head on ``x`` and return the squeezed head output.

        Returns ``None`` when ``self.type`` is neither ViT nor Swin.

        NOTE(review): ``squeeze()`` removes every size-1 dimension, so a batch
        of one also loses its batch axis; kept for caller compatibility.
        """
        # Follow the head's device instead of assuming CUDA whenever it exists.
        x = x.to(next(self.classifier.parameters()).device)

        if self.type == "vit":
            x = self.model._process_input(x)
            # Use a device-local view of the token. The old code rebound
            # self.class_token on every call, and crashed for Swin models,
            # which have no class_token at all.
            class_token = self.class_token.to(x.device)
            batch_class_token = class_token.expand(x.size(0), -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            x = self.model.encoder(x)
            x = x[:, 1:].permute(0, 2, 1)  # drop the class token, channels first
            x = x.unsqueeze(2)
            return self.classifier(x).squeeze()

        if self.type == "swin_transformer":
            x = self.model.features(x)  # [B, H, W, C]
            x = x.permute(0, 3, 1, 2)
            x = self.avgpool(x)  # [B, C, 1, W]
            return self.classifier(x).squeeze()

        return None