Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from utmosv2.dataset._utils import get_dataset_num | |
from utmosv2.model import MultiSpecExtModel, MultiSpecModelV2, SSLExtModel | |
class SSLMultiSpecExtModelV1(nn.Module):
    """Late-fusion model combining an SSL branch and a multi-spectrogram branch.

    Both branches are always restored from per-fold checkpoints written by
    earlier runs. Their ``fc`` heads are then replaced with ``nn.Identity``, and
    a single linear head maps ``[ssl features, spectrogram features, d]`` to
    ``cfg.model.ssl_spec.num_classes`` outputs.

    Args:
        cfg: Experiment config. This class reads
            ``cfg.model.ssl_spec.{ssl_weight, spec_weight, freeze, num_classes}``,
            ``cfg.now_fold`` and ``cfg.split.seed``. The whole object is also
            passed to the sub-model constructors and to ``get_dataset_num``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecModelV2(cfg)
        self._load_fold_checkpoint(self.ssl, cfg, cfg.model.ssl_spec.ssl_weight)
        self._load_fold_checkpoint(
            self.spec_long, cfg, cfg.model.ssl_spec.spec_weight
        )
        if cfg.model.ssl_spec.freeze:
            for param in self.ssl.parameters():
                param.requires_grad = False
            for param in self.spec_long.parameters():
                param.requires_grad = False
        # Record each branch's head width before swapping the head for Identity.
        # This assumes each sub-model's forward ends with ``self.fc``, so after
        # the swap it returns pre-head features. TODO confirm.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            cfg.model.ssl_spec.num_classes,
        )

    @staticmethod
    def _load_fold_checkpoint(module, cfg, run_name):
        """Load ``outputs/<run_name>/fold<now_fold>_s<seed>_best_model.pth`` into *module*.

        ``map_location="cpu"`` lets checkpoints saved from CUDA load on
        CPU-only hosts. ``load_state_dict`` copies the values into *module*'s
        existing parameters, so the module keeps its own device.
        """
        path = (
            f"outputs/{run_name}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
        )
        module.load_state_dict(torch.load(path, map_location="cpu"))

    def forward(self, x1, x2, d):
        """Return the fused prediction.

        Args:
            x1: Input to the SSL branch.
            x2: Input to the spectrogram branch (called with no dataset tensor).
            d: Tensor concatenated along dim 1 after both branch features. Its
                width is expected to be ``self.num_dataset``. Presumably a
                per-dataset indicator. TODO confirm.

        Returns:
            The output of the linear fusion head ``self.fc``.
        """
        # The SSL branch takes a dataset tensor as its second argument. It
        # receives all zeros here, so only the fusion head sees ``d``.
        no_dataset = torch.zeros(x1.shape[0], self.num_dataset, device=x1.device)
        ssl_feat = self.ssl(x1, no_dataset)
        spec_feat = self.spec_long(x2)
        return self.fc(torch.cat([ssl_feat, spec_feat, d], dim=1))
class SSLMultiSpecExtModelV2(nn.Module):
    """Late-fusion model combining an SSL branch and a multi-spectrogram branch.

    Unlike V1, each branch's per-fold checkpoint is loaded only when its weight
    name is set and ``cfg.phase == "train"``. Both branches are called with an
    all-zero dataset tensor. The branches' ``fc`` heads are replaced with
    ``nn.Identity``, and a single linear head maps
    ``[ssl features, spectrogram features, d]`` to
    ``cfg.model.ssl_spec.num_classes`` outputs.

    Args:
        cfg: Experiment config. This class reads
            ``cfg.model.ssl_spec.{ssl_weight, spec_weight, freeze, num_classes}``,
            ``cfg.phase``, ``cfg.now_fold`` and ``cfg.split.seed``. The whole
            object is also passed to the sub-model constructors and to
            ``get_dataset_num``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecExtModel(cfg)
        # Guard order is kept as-is: ``cfg.phase`` is only read when a weight
        # name is set.
        if cfg.model.ssl_spec.ssl_weight is not None and cfg.phase == "train":
            self._load_fold_checkpoint(
                self.ssl, cfg, cfg.model.ssl_spec.ssl_weight
            )
        if cfg.model.ssl_spec.spec_weight is not None and cfg.phase == "train":
            self._load_fold_checkpoint(
                self.spec_long, cfg, cfg.model.ssl_spec.spec_weight
            )
        if cfg.model.ssl_spec.freeze:
            for param in self.ssl.parameters():
                param.requires_grad = False
            for param in self.spec_long.parameters():
                param.requires_grad = False
        # Record each branch's head width before swapping the head for Identity.
        # This assumes each sub-model's forward ends with ``self.fc``, so after
        # the swap it returns pre-head features. TODO confirm.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            cfg.model.ssl_spec.num_classes,
        )

    @staticmethod
    def _load_fold_checkpoint(module, cfg, run_name):
        """Load ``outputs/<run_name>/fold<now_fold>_s<seed>_best_model.pth`` into *module*.

        ``map_location="cpu"`` lets checkpoints saved from CUDA load on
        CPU-only hosts. ``load_state_dict`` copies the values into *module*'s
        existing parameters, so the module keeps its own device.
        """
        path = (
            f"outputs/{run_name}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
        )
        module.load_state_dict(torch.load(path, map_location="cpu"))

    def forward(self, x1, x2, d):
        """Return the fused prediction.

        Args:
            x1: Input to the SSL branch.
            x2: Input to the spectrogram branch.
            d: Tensor concatenated along dim 1 after both branch features. Its
                width is expected to be ``self.num_dataset``. Presumably a
                per-dataset indicator. TODO confirm.

        Returns:
            The output of the linear fusion head ``self.fc``.
        """
        # Take batch size and device from the *input* before any rebinding.
        # Both branches receive an all-zero dataset tensor, so only the fusion
        # head sees ``d``.
        batch, device = x1.shape[0], x1.device
        ssl_feat = self.ssl(
            x1, torch.zeros(batch, self.num_dataset, device=device)
        )
        spec_feat = self.spec_long(
            x2, torch.zeros(batch, self.num_dataset, device=device)
        )
        return self.fc(torch.cat([ssl_feat, spec_feat, d], dim=1))