Spaces:
baselqt
/
No application file

File size: 1,523 Bytes
22b8701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np

import torch
import torch.nn.functional as F
from torchvision import transforms

from typing import Iterable, Union
from pathlib import Path


class FaceId(torch.nn.Module):
    """Embed images with a pre-trained network loaded from disk.

    The network is loaded with ``torch.load``, switched to eval mode and moved
    to ``device``. Inputs are converted with ``ToTensor`` and normalized with
    per-channel mean/std (the values match the widely used ImageNet
    statistics), resized to ``input_shape`` and passed through the network.

    NOTE(review): judging by the class name the network is a face-identity
    encoder; nothing in this file verifies that, so confirm against the
    checkpoint that is actually shipped.
    """

    def __init__(
        self, model_path: Path, device: str, input_shape: Iterable[int] = (112, 112)
    ):
        """Load the network and prepare the preprocessing pipeline.

        Args:
            model_path: Path of the serialized model; it must unpickle to an
                object exposing ``eval()`` and ``named_parameters()``.
            device: Device string understood by ``torch.device`` (for example
                ``"cpu"`` or ``"cuda:0"``).
            input_shape: Spatial size the images are resized to before they
                are fed to the network.

        Raises:
            AssertionError: If any parameter of the loaded network still has
                ``requires_grad`` set.
        """
        super().__init__()

        # Materialize once so any iterable (list, generator, ...) can be
        # reused on every forward() call as the `size` of F.interpolate.
        self.input_shape = tuple(input_shape)

        # SECURITY: torch.load unpickles the file and can therefore execute
        # arbitrary code; only load checkpoints from trusted sources.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which may reject a fully pickled module like the one expected here.
        # Confirm against the pinned torch version.
        self.net = torch.load(model_path, map_location=torch.device("cpu"))
        self.net.eval()

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # The checkpoint is expected to be frozen already. An explicit raise
        # (instead of `assert`) keeps the check alive under `python -O` while
        # preserving the exception type and message.
        for name, param in self.net.named_parameters():
            if param.requires_grad:
                raise AssertionError(
                    f"Parameter {name}: requires_grad: {param.requires_grad}"
                )

        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self, img_id: Union[np.ndarray, Iterable[np.ndarray]], normalize: bool = True
    ) -> torch.Tensor:
        """Compute the network output for one image or a batch of images.

        Args:
            img_id: Either a single image (an array with at most three
                dimensions, e.g. H x W x C, or any non-iterable object the
                transform accepts) or an iterable of such images. A 4-D array
                is treated as a batch and iterated along its first axis. All
                images of a batch must produce tensors of the same shape,
                because they are stacked before resizing.
            normalize: When true, L2-normalize the output along dimension 1.

        Returns:
            The network output for the batch; a single image yields a batch
            of size one.
        """
        # An ndarray is itself iterable, so it has to be recognised as a
        # single image *before* the generic Iterable test. Otherwise an
        # H x W x C image would be split into rows and fail in the transform.
        is_single_array = isinstance(img_id, np.ndarray) and img_id.ndim <= 3

        if not is_single_array and isinstance(img_id, Iterable):
            batch = torch.stack([self.transform(x) for x in img_id], dim=0)
        else:
            batch = self.transform(img_id).unsqueeze(0)

        batch = batch.to(self.device)

        # Resize to the network's input size (F.interpolate's default mode).
        resized = F.interpolate(batch, size=self.input_shape)
        latent_id = self.net(resized)
        return F.normalize(latent_id, p=2, dim=1) if normalize else latent_id