Spaces:
baselqt
/
No application file

simswap55 / src /FaceId /faceid.py
LB5's picture
Upload 45 files
22b8701
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from typing import Iterable, Union
from pathlib import Path
class FaceId(torch.nn.Module):
    """Frozen embedding network wrapper that maps images to latent vectors.

    The wrapped network is loaded from ``model_path`` with ``torch.load``,
    switched to eval mode and moved to ``device``. Inputs are converted with
    ``ToTensor`` + ``Normalize``, resized to ``input_shape`` and passed
    through the network; the output is optionally L2-normalized.
    """

    def __init__(
        self, model_path: Path, device: str, input_shape: Iterable[int] = (112, 112)
    ):
        """Load the network and prepare the preprocessing pipeline.

        Args:
            model_path: Path of the serialized network passed to ``torch.load``.
            device: Device string (e.g. ``"cpu"`` or ``"cuda"``) the module is
                moved to after loading.
            input_shape: Spatial size handed to ``F.interpolate`` before the
                network is applied.

        Raises:
            AssertionError: If any parameter of the loaded network still has
                ``requires_grad`` set.
        """
        super().__init__()
        self.input_shape = input_shape

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        # The object is used as a module below (.eval(), .named_parameters()),
        # so the file presumably stores a whole pickled model, not a state dict.
        self.net = torch.load(model_path, map_location=torch.device("cpu"))
        self.net.eval()

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Per-channel mean / std for three-channel input.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # Sanity check: the loaded network is expected to be fully frozen.
        for name, param in self.net.named_parameters():
            assert (
                not param.requires_grad
            ), f"Parameter {name}: requires_grad: {param.requires_grad}"

        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self, img_id: Union[np.ndarray, Iterable[np.ndarray]], normalize: bool = True
    ) -> torch.Tensor:
        """Compute the latent vector(s) for one image or a batch of images.

        Args:
            img_id: Either a single image or an iterable of images. An
                ``np.ndarray`` with at most three dimensions is treated as a
                single image; a 4-D array, list, tuple or other iterable is
                treated as a batch and transformed element by element.
            normalize: When true, L2-normalize the output along ``dim=1``.

        Returns:
            The network output for the batch (a single image yields a batch
            of size one), normalized if requested.
        """
        # np.ndarray is itself iterable, so a bare ``isinstance(..., Iterable)``
        # check would iterate over the rows of a single image. Decide
        # "single image" explicitly before falling back to the iterable test.
        is_single_array = isinstance(img_id, np.ndarray) and img_id.ndim <= 3
        if is_single_array or not isinstance(img_id, Iterable):
            images = [img_id]
        else:
            images = img_id

        batch = torch.stack([self.transform(image) for image in images], dim=0)
        batch = batch.to(self.device)

        # Resize to the spatial size the network is fed with.
        resized = F.interpolate(batch, size=self.input_shape)
        latent_id = self.net(resized)

        if normalize:
            return F.normalize(latent_id, p=2, dim=1)
        return latent_id