from argparse import Namespace
import time
import torch
import torchvision.transforms as transforms
import dlib
from PIL import Image

from pixel2style2pixel.utils.common import tensor2im
from pixel2style2pixel.models.psp import pSp
from pixel2style2pixel.scripts.align_all_parallel import align_face


class InversionModel:
    """Wraps a pretrained pSp encoder for inverting face images into latent codes."""

    def __init__(self, checkpoint_path: str, dlib_path: str) -> None:
        self.dlib_path = dlib_path
        self.dlib_predictor = dlib.shape_predictor(dlib_path)

        self.transform_image = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        # Load the checkpoint on CPU first, then restore the training options
        # stored alongside the weights.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["checkpoint_path"] = checkpoint_path
        opts["learn_in_w"] = False
        opts["output_size"] = 1024

        self.opts = Namespace(**opts)
        self.net = pSp(self.opts)
        self.net.eval()
        self.net.cuda()
        print("Model successfully loaded!")

    def run_alignment(self, image_path: str):
        """Crop and align the face in the image using the dlib landmark predictor."""
        aligned_image = align_face(filepath=image_path, predictor=self.dlib_predictor)
        print("Aligned image has shape: {}".format(aligned_image.size))
        return aligned_image

    def inference(self, image_path: str):
        """Invert the face at `image_path`, returning the result image and latents."""
        input_image = self.run_alignment(image_path)
        input_image = input_image.resize((256, 256))
        transformed_image = self.transform_image(input_image)

        with torch.no_grad():
            tic = time.time()
            result_image, latents = self.net(
                transformed_image.unsqueeze(0).to("cuda").float(),
                return_latents=True,
                randomize_noise=False,
            )
            toc = time.time()
            print("Inference took {:.4f} seconds.".format(toc - tic))

        res_image = tensor2im(result_image[0])
        latents_np = latents.detach().cpu().numpy()
        return (
            res_image,
            {
                "w1": latents_np,
                # Keep an independent copy of the initial latent so that later
                # edits to "w1" cannot mutate "w1_initial".
                "w1_initial": latents_np.copy(),
            },
        )
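

# --- Hypothetical usage sketch (not part of the original module) ---
# The paths below are placeholders: point them at your pSp checkpoint and the
# dlib 68-landmark predictor file. Requires a CUDA-capable GPU, since the
# model is moved to CUDA in __init__.
if __name__ == "__main__":
    model = InversionModel(
        checkpoint_path="pretrained_models/psp_checkpoint.pt",
        dlib_path="shape_predictor_68_face_landmarks.dat",
    )
    image, latents = model.inference("input_face.jpg")
    image.save("inverted_face.jpg")
    print("Latent code shape:", latents["w1"].shape)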