# Source listing metadata: file size 1,942 bytes, commit 742d952.
import cv2
import torch
import torch.nn as nn
import numpy as np
from insightface.app import FaceAnalysis
from pytorch_msssim import ssim
import Image
class StyleTransferLoss(nn.Module):
    """Content + face-identity loss comparing a generated image with a target.

    ``forward`` returns two terms:

    * a structural content loss, ``1 - SSIM(output, target)``;
    * an identity loss derived from the cosine similarity of face embeddings
      extracted from each image, or ``None`` when no face is detected in
      either image.

    NOTE(review): the identity term is computed from embeddings extracted on
    detached tensors via NumPy/OpenCV, so it carries no gradient back to
    ``output_image``. It works as a metric, but back-propagating through it
    will not update whatever produced ``output_image``.
    """

    def __init__(self, device='cuda', face_analysis=None):
        """Set up the face analyser and loss helpers.

        Args:
            device: Device the extracted identity embeddings are moved to.
            face_analysis: Optional pre-built ``FaceAnalysis`` instance. When
                omitted, one is created with the CUDA execution provider
                (CPU as fallback) and prepared with ``ctx_id=0`` and a
                128x128 detection size.
        """
        super().__init__()
        if face_analysis is None:
            face_analysis = FaceAnalysis(
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
            face_analysis.prepare(ctx_id=0, det_size=(128, 128))
        self.face_analysis = face_analysis
        self.device = device
        # Face embeddings are compared as 1-D vectors, hence dim=0.
        self.cosine_similarity = nn.CosineSimilarity(dim=0)
        # Pixel-wise MSE alternative to the SSIM content loss; forward() does
        # not use it, but it is kept as a public attribute.
        self.content_loss = nn.MSELoss()

    @staticmethod
    def _to_uint8_hwc(image):
        """Convert a single float image tensor to an H x W x C uint8 array.

        Assumes ``image`` is (C, H, W) or (1, C, H, W) with values nominally
        in [0, 1] (TODO confirm with callers). Singleton dimensions are
        squeezed away, so a batch of more than one image is not supported.
        Values are clamped to [0, 1] before scaling so out-of-range pixels
        saturate at 0/255 instead of wrapping around in the uint8 cast.
        """
        chw = image.detach().cpu().squeeze().clamp(0.0, 1.0)
        return (chw.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    def extract_face_latent(self, image):
        """Return the identity embedding of the first face found in ``image``.

        Args:
            image: Float image tensor treated as RGB; see ``_to_uint8_hwc``
                for the assumed layout and value range.

        Returns:
            The embedding as a tensor on ``self.device``, or ``None`` when the
            face analyser detects no face.
        """
        # The tensor is treated as RGB and converted to OpenCV's BGR channel
        # order before detection.
        face_np = cv2.cvtColor(self._to_uint8_hwc(image), cv2.COLOR_RGB2BGR)
        faces = self.face_analysis.get(face_np)
        if len(faces) == 0:
            return None
        # NOTE(review): ``Image.getLatent`` is not a PIL API; presumably a
        # project-local helper whose first element is the embedding. The
        # insightface face objects may expose the embedding directly
        # (e.g. ``.normed_embedding``) — verify which is intended.
        return torch.tensor(Image.getLatent(faces[0])[0]).to(self.device)

    def forward(self, output_image, target_content):
        """Compute the content and identity losses.

        Args:
            output_image: Generated images, presumably (N, C, H, W) with values
                in [0, 1] (``ssim`` is called with ``data_range=1.0``).
            target_content: Reference images with the same layout.

        Returns:
            ``(content_loss, identity_loss)``. ``content_loss`` is
            ``1 - SSIM``. ``identity_loss`` is
            ``10 * (1 - (cos_sim + 1) / 2) ** 2``, or ``None`` if no face was
            detected in either image.
        """
        # SSIM-based content term (the MSE alternative lives in self.content_loss).
        content_loss = 1 - ssim(output_image, target_content, data_range=1.0)

        output_embedding = self.extract_face_latent(output_image)
        target_embedding = self.extract_face_latent(target_content)
        if output_embedding is None or target_embedding is None:
            return content_loss, None

        similarity = self.cosine_similarity(output_embedding, target_embedding)
        # Map cosine similarity from [-1, 1] to a [0, 1] distance, square it,
        # and scale by 10 (presumably a weighting relative to the content term).
        identity_loss = 1 - ((similarity + 1) / 2)
        identity_loss = identity_loss ** 2 * 10
        return content_loss, identity_loss