|
import numpy as np |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
|
|
from typing import Iterable, Union |
|
from pathlib import Path |
|
|
|
|
|
class FaceId(torch.nn.Module):
    """Frozen face-identity embedding network.

    Wraps a pickled ``torch.nn.Module`` loaded from disk. Input images are
    preprocessed with torchvision transforms and resized to ``input_shape``.
    The network's output is returned, optionally L2-normalized.
    """

    def __init__(
        self, model_path: Path, device: str, input_shape: Iterable[int] = (112, 112)
    ):
        """Load the network and move it to ``device``.

        Args:
            model_path: File produced by saving a whole module object. The
                loaded object must support ``.eval()`` and
                ``.named_parameters()``, so a plain state dict will not work.
            device: Device spec accepted by ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda:0"``).
            input_shape: Spatial size ``(H, W)`` that images are resized to
                before being fed to the network. An int is also accepted and
                passed straight to ``F.interpolate``.

        Raises:
            ValueError: If any parameter of the loaded network has
                ``requires_grad=True``; the network is expected to be frozen.
        """
        super().__init__()

        # F.interpolate needs a concrete sequence (or an int), not an arbitrary
        # iterable such as a generator, so materialize it once here.
        self.input_shape = (
            tuple(input_shape) if isinstance(input_shape, Iterable) else input_shape
        )

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code, so only load trusted checkpoints. Recent PyTorch
        # releases default to weights_only=True, which rejects a pickled
        # nn.Module; confirm the torch version in use.
        self.net = torch.load(model_path, map_location=torch.device("cpu"))
        self.net.eval()

        # ToTensor converts an HWC ndarray/PIL image to a CHW float tensor.
        # Normalize then applies per-channel mean/std (the common ImageNet values).
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        for name, param in self.net.named_parameters():
            if param.requires_grad:
                raise ValueError(
                    f"Parameter {name}: requires_grad: {param.requires_grad}"
                )

        self.device = torch.device(device)
        # self.net is a registered submodule, so this moves its weights as well.
        self.to(self.device)

    def forward(
        self, img_id: Union[np.ndarray, Iterable[np.ndarray]], normalize: bool = True
    ) -> torch.Tensor:
        """Compute identity embeddings for one image or a batch of images.

        Args:
            img_id: One of:
                - a single image as an ``np.ndarray`` with at most 3 dims
                  (e.g. ``H x W x C``);
                - an iterable of such images (a list, or an ndarray with 4+
                  dims whose first axis indexes images).
                A non-iterable object (e.g. a PIL image) is treated as a
                single image.
            normalize: If True, L2-normalize the network output along dim 1.

        Returns:
            The network output for the batch; batch size is 1 for a single
            image. Presumably shaped ``(batch, embedding_dim)``; confirm
            against the loaded model.

        Raises:
            ValueError: If ``img_id`` is an empty iterable.
        """
        # np.ndarray is itself Iterable, so it must be checked first.
        # Otherwise a single image would be iterated row by row and fail in
        # the transform.
        if isinstance(img_id, np.ndarray) and img_id.ndim <= 3:
            batch = self.transform(img_id).unsqueeze(0)
        elif isinstance(img_id, Iterable):
            tensors = [self.transform(x) for x in img_id]
            if not tensors:
                raise ValueError("img_id must contain at least one image")
            batch = torch.stack(tensors, dim=0)
        else:
            batch = self.transform(img_id).unsqueeze(0)

        batch = batch.to(self.device)

        # Resize to the network's expected spatial size (default 'nearest' mode).
        resized = F.interpolate(batch, size=self.input_shape)
        latent_id = self.net(resized)
        return F.normalize(latent_id, p=2, dim=1) if normalize else latent_id
|
|