# UTMOSv2 - utmosv2/model/ssl_multispec.py
# Source snapshot: commit b55d767 ("Resolved conflict in README.md") by kAIto47802
import torch
import torch.nn as nn
from utmosv2.dataset._utils import get_dataset_num
from utmosv2.model import MultiSpecExtModel, MultiSpecModelV2, SSLExtModel
class SSLMultiSpecExtModelV1(nn.Module):
    """Fusion head over a pretrained SSL branch and a multi-spectrogram branch.

    Builds an ``SSLExtModel`` and a ``MultiSpecModelV2`` from ``cfg`` and
    restores both from fold-specific checkpoints under ``outputs/``. It then
    optionally freezes them and replaces their ``fc`` heads with
    ``nn.Identity``. A new linear layer is trained on the concatenation
    ``[ssl features, spectrogram features, d]``.

    Unlike ``SSLMultiSpecExtModelV2``, both checkpoints are loaded
    unconditionally. The files must therefore exist whenever this class is
    constructed.

    Config fields read:
        ``cfg.model.ssl_spec.ssl_weight`` / ``spec_weight``: run directory
            names under ``outputs/`` that hold the sub-model checkpoints.
        ``cfg.model.ssl_spec.freeze``: if true, every sub-model parameter gets
            ``requires_grad=False``. The new ``fc`` head stays trainable.
        ``cfg.model.ssl_spec.num_classes``: output width of the new head.
        ``cfg.now_fold`` and ``cfg.split.seed``: select the checkpoint file.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecModelV2(cfg)
        self._load_checkpoint(self.ssl, cfg.model.ssl_spec.ssl_weight, cfg)
        self._load_checkpoint(self.spec_long, cfg.model.ssl_spec.spec_weight, cfg)
        if cfg.model.ssl_spec.freeze:
            # Freeze the whole backbones. Their fc layers are discarded below,
            # so only the new fusion head remains trainable.
            self.ssl.requires_grad_(False)
            self.spec_long.requires_grad_(False)
        # Keep the feature widths feeding each sub-model's head, then drop the
        # heads so the sub-models return penultimate features.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            cfg.model.ssl_spec.num_classes,
        )

    @staticmethod
    def _load_checkpoint(module, run_name, cfg):
        """Load ``outputs/<run_name>/fold<now_fold>_s<seed>_best_model.pth`` into ``module``.

        Tensors are mapped to CPU first, so checkpoints saved on a GPU can be
        restored on a CPU-only host. ``load_state_dict`` then copies them into
        the module's existing parameters.
        """
        path = f"outputs/{run_name}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
        module.load_state_dict(torch.load(path, map_location="cpu"))

    def forward(self, x1, x2, d):
        """Score a batch.

        Args:
            x1: input for the SSL branch, with the batch on dim 0.
            x2: input for the spectrogram branch, with the batch on dim 0.
            d: dataset indicator of width ``num_dataset``. Only the fusion
                head sees it (presumably one-hot; confirm with the dataset).

        Returns:
            Output of the fusion head; the last dim has size ``num_classes``.
        """
        # The SSL sub-model takes its own dataset input. Feed it zeros so the
        # real ``d`` only reaches the fusion head.
        no_dataset = torch.zeros(x1.shape[0], self.num_dataset, device=x1.device)
        ssl_feat = self.ssl(x1, no_dataset)
        spec_feat = self.spec_long(x2)
        return self.fc(torch.cat([ssl_feat, spec_feat, d], dim=1))
class SSLMultiSpecExtModelV2(nn.Module):
    """Fusion head over an SSL branch and a multi-spectrogram (Ext) branch.

    Builds an ``SSLExtModel`` and a ``MultiSpecExtModel`` from ``cfg``. It
    optionally restores them from fold-specific checkpoints under ``outputs/``
    (training phase only) and optionally freezes them. Their ``fc`` heads are
    replaced with ``nn.Identity``, and a new linear layer is trained on
    ``[ssl features, spectrogram features, d]``.

    Config fields read:
        ``cfg.phase``: sub-model checkpoints are loaded only when ``"train"``.
        ``cfg.model.ssl_spec.ssl_weight`` / ``spec_weight``: run directory
            names under ``outputs/``. ``None`` skips loading that branch.
        ``cfg.model.ssl_spec.freeze``: if true, every sub-model parameter gets
            ``requires_grad=False``. The new ``fc`` head stays trainable.
        ``cfg.model.ssl_spec.num_classes``: output width of the new head.
        ``cfg.now_fold`` and ``cfg.split.seed``: select the checkpoint file.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecExtModel(cfg)
        ssl_spec = cfg.model.ssl_spec
        # Sub-model weights are only restored for training. Outside training,
        # the full model's weights are presumably loaded afterwards (TODO
        # confirm with the caller).
        if cfg.phase == "train":
            if ssl_spec.ssl_weight is not None:
                self._load_checkpoint(self.ssl, ssl_spec.ssl_weight, cfg)
            if ssl_spec.spec_weight is not None:
                self._load_checkpoint(self.spec_long, ssl_spec.spec_weight, cfg)
        if ssl_spec.freeze:
            # Freeze the whole backbones. Their fc layers are discarded below,
            # so only the new fusion head remains trainable.
            self.ssl.requires_grad_(False)
            self.spec_long.requires_grad_(False)
        # Keep the feature widths feeding each sub-model's head, then drop the
        # heads so the sub-models return penultimate features.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            ssl_spec.num_classes,
        )

    @staticmethod
    def _load_checkpoint(module, run_name, cfg):
        """Load ``outputs/<run_name>/fold<now_fold>_s<seed>_best_model.pth`` into ``module``.

        Tensors are mapped to CPU first, so checkpoints saved on a GPU can be
        restored on a CPU-only host. ``load_state_dict`` then copies them into
        the module's existing parameters.
        """
        path = f"outputs/{run_name}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
        module.load_state_dict(torch.load(path, map_location="cpu"))

    def _no_dataset(self, x):
        """Return an all-zero ``(x.shape[0], num_dataset)`` tensor on ``x``'s device."""
        return torch.zeros(x.shape[0], self.num_dataset, device=x.device)

    def forward(self, x1, x2, d):
        """Score a batch.

        Args:
            x1: input for the SSL branch, with the batch on dim 0.
            x2: input for the spectrogram branch, with the batch on dim 0.
            d: dataset indicator of width ``num_dataset``. Only the fusion
                head sees it (presumably one-hot; confirm with the dataset).

        Returns:
            Output of the fusion head; the last dim has size ``num_classes``.
        """
        # Both sub-models take their own dataset input. Each gets a fresh
        # zero tensor sized from its own input, so the real ``d`` only
        # reaches the fusion head.
        ssl_feat = self.ssl(x1, self._no_dataset(x1))
        spec_feat = self.spec_long(x2, self._no_dataset(x2))
        return self.fc(torch.cat([ssl_feat, spec_feat, d], dim=1))