import torch
from torch import nn
import torch.nn.functional as F

from criteria.model_irse import Backbone
from criteria.backbones import get_model


class IDLoss(nn.Module):
    """
    Identity loss: one minus the cosine similarity between face-recognition
    embeddings of the people in two images.

    Taken from TreB1eN's [1] implementation of InsightFace [2, 3], as used in
    pixel2style2pixel [4].

    [1] https://github.com/TreB1eN/InsightFace_Pytorch
    [2] https://github.com/deepinsight/insightface
    [3] Deng, Jiankang and Guo, Jia and Niannan, Xue and Zafeiriou, Stefanos.
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019.
    [4] https://github.com/eladrich/pixel2style2pixel
    """

    def __init__(self, model_path, official=False):
        """
        Arguments:
            model_path (str): Path to the pretrained face-recognition weights
                (IR-SE50 by default, or the official ArcFace ResNet-100 weights
                when ``official=True``).
            official (bool): If True, use the official InsightFace ResNet-100
                backbone instead of the IR-SE50 backbone.
        """
        super(IDLoss, self).__init__()
        print("Loading ResNet ArcFace")
        self.official = official
        if official:
            self.facenet = get_model("r100", fp16=False)
        else:
            self.facenet = Backbone(
                input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
            )
        self.facenet.load_state_dict(torch.load(model_path))
        # Face-recognition backbones expect 112x112 inputs.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        # Crop the face region (the fixed offsets assume aligned inputs of
        # roughly 256x256), then resize to the backbone's 112x112 resolution.
        x = x[:, :, 35:223, 32:220]
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, x, y):
        """
        Arguments:
            x (Tensor): The batch of original images.
            y (Tensor): The batch of generated images.

        Returns:
            loss (Tensor): One minus the mean cosine similarity between the
                features of the original and generated images.
        """
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        if self.official:
            # The IR-SE50 backbone is expected to return L2-normalized
            # embeddings already; the official ResNet-100 output is
            # normalized here so the dot product below is a cosine similarity.
            x_feats = F.normalize(x_feats)
            y_feats = F.normalize(y_feats)
        loss = (1 - (x_feats * y_feats).sum(dim=1)).mean()
        return loss
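

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). Assumptions:
    # the checkpoint path below is hypothetical, and the 256x256 input size
    # is inferred from the crop offsets in extract_feats.
    id_loss = IDLoss(model_path="pretrained_models/model_ir_se50.pth")

    real = torch.randn(2, 3, 256, 256)       # batch of original images
    generated = torch.randn(2, 3, 256, 256)  # batch of generated images

    # Gradients are disabled here only for the demonstration; during training
    # the loss is differentiated with respect to the generated images.
    with torch.no_grad():
        loss = id_loss(real, generated)
    print(f"ID loss: {loss.item():.4f}")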