|
import torch |
|
from torch import nn |
|
from configs.paths_config import model_paths |
|
from models.encoders.model_irse import Backbone |
|
|
|
|
|
class IDLoss(nn.Module):
    """Identity-preservation loss built on a pretrained IR-SE50 ArcFace backbone.

    Images are embedded with ``self.facenet`` and compared by the dot product
    of their embeddings.
    """

    def __init__(self):
        super().__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # load_state_dict copies the loaded tensors into the module's own
        # parameters, so mapping them to the CPU first is safe; it also lets a
        # checkpoint that was saved from a GPU load on a CPU-only machine.
        state_dict = torch.load(model_paths['ir_se50'], map_location='cpu')
        self.facenet.load_state_dict(state_dict)
        # The backbone was configured for 112x112 inputs, hence the pooling size.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        # NOTE(review): nn.Module.train() on this module (or a parent holding
        # it) switches facenet back to training mode; callers presumably keep
        # this loss in eval mode.
        self.facenet.eval()

    def extract_feats(self, x):
        """Embed a batch of images with the ArcFace backbone.

        Args:
            x: 4-D image batch indexed as ``x[batch, channel, row, col]``.
                NOTE(review): the fixed crop below appears tuned for 256x256
                aligned faces -- confirm against callers.

        Returns:
            The output of ``self.facenet`` on the cropped and pooled batch;
            ``forward`` treats each row of it as one sample's embedding.
        """
        # Fixed spatial crop (presumably the face region of the aligned image).
        x = x[:, :, 35:223, 32:220]
        # Adaptive pooling resizes the crop to the backbone's 112x112 input.
        x = self.face_pool(x)
        return self.facenet(x)

    def forward(self, y_hat, y, x, label=None, weights=None):
        """Compute the identity loss of ``y_hat`` against ``y``.

        Args:
            y_hat: generated images; the only input gradients flow back through.
            y: target images whose identity should be preserved (treated as a
                constant).
            x: source images; used only for the logged similarities.
            label: optional suffix; when given, every log key gets ``_{label}``.
            weights: optional per-sample weights, indexed by sample position;
                sample ``i``'s loss is multiplied by ``weights[i]``.

        Returns:
            A tuple ``(loss, sim_improvement, id_logs)``:
            ``loss`` is ``sum_i w_i * (1 - <y_hat_i, y_i>) / n`` (``w_i = 1`` when
            ``weights`` is None); ``sim_improvement`` is the float mean of
            ``<y_hat_i, y_i> - <y_i, x_i>``; ``id_logs`` holds one dict of float
            similarities (``diff_target``, ``diff_input``, ``diff_views``) per
            sample.

        Raises:
            ValueError: if the batch is empty.
        """
        n_samples = x.shape[0]
        if n_samples == 0:
            raise ValueError('IDLoss.forward requires a non-empty batch')

        # The embeddings of x and y never enter the returned loss tensor (x is
        # only logged, y is a constant target), so skip building autograd
        # graphs for them.
        with torch.no_grad():
            x_feats = self.extract_feats(x)
            y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)

        # Row-wise dot products for the whole batch at once.
        diff_target = (y_hat_feats * y_feats).sum(dim=1)
        diff_input = (y_hat_feats * x_feats).sum(dim=1)
        diff_views = (y_feats * x_feats).sum(dim=1)

        # One host transfer per vector instead of several float() syncs per sample.
        target_vals = diff_target.detach().tolist()
        input_vals = diff_input.detach().tolist()
        views_vals = diff_views.detach().tolist()

        suffix = '' if label is None else f'_{label}'
        id_logs = [
            {
                f'diff_target{suffix}': t,
                f'diff_input{suffix}': i,
                f'diff_views{suffix}': v,
            }
            for t, i, v in zip(target_vals, input_vals, views_vals)
        ]

        # Accumulate per sample so arbitrary indexable ``weights`` (list,
        # tensor, ...) keep working exactly as ``weights[i] * loss``.
        total_loss = 0
        for i in range(n_samples):
            loss = 1 - diff_target[i]
            if weights is not None:
                loss = weights[i] * loss
            total_loss += loss

        sim_improvement = sum(t - v for t, v in zip(target_vals, views_vals))

        return total_loss / n_samples, sim_improvement / n_samples, id_logs
|
|