echen01
fix device
5b7158a
raw
history blame
2.13 kB
import torch
from torch import nn
import torch.nn.functional as F
from criteria.model_irse import Backbone
from criteria.backbones import get_model
class IDLoss(nn.Module):
    """
    Computes an identity loss (1 - embedding similarity) between people in two images.

    Taken from TreB1eN's [1] implementation of InsightFace [2, 3], as used in pixel2style2pixel [4].
    [1] https://github.com/TreB1eN/InsightFace_Pytorch
    [2] https://github.com/deepinsight/insightface
    [3] Deng, Jiankang and Guo, Jia and Niannan, Xue and Zafeiriou, Stefanos.
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019
    [4] https://github.com/eladrich/pixel2style2pixel
    """

    def __init__(self, model_path, official=False, device="cpu"):
        """
        Arguments:
            model_path (str): Path to the face-recognition checkpoint. This is the
                IR-SE50 ``Backbone`` weights when ``official`` is False, otherwise
                the weights for the ``get_model("r100")`` network.
            official (bool): If True, build the network with
                ``get_model("r100", fp16=False)`` and L2-normalize its embeddings
                in ``forward``; otherwise build the IR-SE50 ``Backbone``.
            device (str | torch.device): Device the checkpoint is loaded onto and
                the face network is moved to.
        """
        super().__init__()
        print("Loading ResNet ArcFace")
        self.official = official
        if official:
            self.facenet = get_model("r100", fp16=False)
        else:
            self.facenet = Backbone(
                input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
            )
        self.facenet.load_state_dict(torch.load(model_path, map_location=device))
        # load_state_dict copies the loaded tensors into the already-existing
        # parameters, so map_location alone leaves the network where it was
        # built; move it explicitly.
        self.facenet.to(device)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def train(self, mode=True):
        """Set the module's training flag while keeping the face network in eval mode.

        The face network is used as a fixed feature extractor. Without this
        override, a ``.train()`` call on this module (or on any container that
        holds it) would re-enable dropout / BatchNorm updates inside it.

        Arguments:
            mode (bool): Training flag applied to this module.
        Returns:
            IDLoss: ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.facenet.eval()
        return self

    def extract_feats(self, x):
        """Crop, resize to 112x112 and embed a batch of images.

        Arguments:
            x (Tensor): Image batch indexed as (batch, channels, height, width).
        Returns:
            Tensor: Output of ``self.facenet`` for the pooled crops.
        """
        # Fixed face window (rows 35:223, cols 32:220). This presumably targets
        # 256x256 aligned face images as in pSp -- TODO confirm input resolution.
        x = x[:, :, 35:223, 32:220]
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, x, y):
        """
        Arguments:
            x (Tensor): The batch of original images
            y (Tensor): The batch of generated images
        Returns:
            loss (Tensor): Scalar, the batch mean of ``1 - <x_feat, y_feat>``.
                This is ``1 - cosine similarity`` when the embeddings are unit
                norm. They are explicitly L2-normalized when ``official`` is
                True; otherwise they are used as returned by ``Backbone``
                (presumably already normalized -- confirm).
        """
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        if self.official:
            # The r100 network's raw embeddings are not normalized here, so
            # normalize to turn the dot product into a cosine similarity.
            x_feats = F.normalize(x_feats)
            y_feats = F.normalize(y_feats)
        loss = (1 - (x_feats * y_feats).sum(dim=1)).mean()
        return loss