File size: 3,285 Bytes
e34aada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Reconstructor wrapper around the Deep3DFaceRecon_pytorch face-reconstruction model.
The PyTorch Deep3D_Recon is about 8x faster than the TF-based one (16 s/iter ==> 2 s/iter).
"""

import os
# os.environ['PYTHONPATH'] = os.environ['PYTHONPATH'] + ":" + os.path.abspath("deep_3drecon")
import torch
import torch.nn as nn
from .deep_3drecon_models.facerecon_model import FaceReconModel
from .util.preprocess import align_img
from PIL import Image
import numpy as np
from .util.load_mats import load_lm3d
import torch 
import pickle as pkl
from PIL import Image

from utils.commons.tensor_utils import convert_to_tensor, convert_to_np

# Model options, unpickled once at import time. The path is relative to the
# current working directory, so importing this module only works when the CWD
# contains a `deep_3drecon/` directory holding `reconstructor_opt.pkl`.
# NOTE(review): pickle.load can execute arbitrary code from the file; only ever
# point this at a trusted options file.
with open("deep_3drecon/reconstructor_opt.pkl", "rb") as f:
    opt = pkl.load(f) 
    
class Reconstructor(nn.Module):
    """Regress per-image coefficients with a pretrained ``FaceReconModel``.

    The wrapped model is built from the module-level ``opt`` options, placed on
    ``'cuda:0'`` and put in eval mode. Inputs are raw images plus per-image
    landmarks; each pair is aligned with ``align_img`` before inference.
    """

    def __init__(self):
        super().__init__()
        self.model = FaceReconModel(opt)
        self.model.setup(opt)
        # NOTE(review): device is hard-coded, so a CUDA device is required.
        self.model.device = 'cuda:0'
        self.model.parallelize()
        self.model.eval()
        # Reference landmarks loaded from opt.bfm_folder; used as the alignment
        # target passed to align_img in preprocess_data.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

    @staticmethod
    def _flip_landmarks_y(lm, height):
        """Return a (-1, 2) copy of numpy-like ``lm`` with y mirrored as ``height - 1 - y``.

        A copy is made so the caller's landmark array is never modified.
        """
        lm = np.array(lm, copy=True).reshape([-1, 2])
        lm[:, -1] = height - 1 - lm[:, -1]
        return lm

    @staticmethod
    def _to_uint8_hwc(imgs):
        """Convert a CPU float tensor (B, C, H, W) scaled to [0, 1] into a uint8 (B, H, W, C) array."""
        return (imgs.permute(0, 2, 3, 1) * 255).int().numpy().astype(np.uint8)

    def preprocess_data(self, im, lm, lm3d_std):
        """Align one image/landmark pair and convert it to model-ready tensors.

        Args:
            im: a single image with shape (H, W, C); presumably uint8 RGB, since
                it is passed to ``Image.fromarray`` — TODO confirm.
            lm: landmarks for ``im``, reshaped to (-1, 2) (x, y) pairs. The y
                coordinate is mirrored (y -> H - 1 - y) before alignment;
                presumably this converts top-origin to the bottom-origin
                convention ``align_img`` expects — confirm. The caller's
                ``lm`` is not modified.
            lm3d_std: reference landmarks passed to ``align_img`` as the target.

        Returns:
            ``(im, lm)``. ``im`` is a float32 tensor of shape (1, C, H', W'),
            divided by 255. ``lm`` is the aligned landmarks as a tensor with a
            leading batch dimension of 1.
        """
        H, _, _ = im.shape
        # Flip on a copy: reshape usually returns a view, and flipping that in
        # place would corrupt the caller's batched landmarks.
        lm = self._flip_landmarks_y(convert_to_np(lm), H)

        _, im, lm, _ = align_img(Image.fromarray(convert_to_np(im)), lm, convert_to_np(lm3d_std))
        im = torch.tensor(np.array(im) / 255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
        lm = torch.tensor(lm).unsqueeze(0)
        return im, lm

    @torch.no_grad()
    def recon_coeff(self, batched_images, batched_lm5, return_image=True, batch_mode=True):
        """Predict coefficients for a batch of images.

        Args:
            batched_images: images indexable along dim 0; each item is accepted
                by ``preprocess_data`` (shape (H, W, C)).
            batched_lm5: per-image landmarks indexable along dim 0 (presumably
                5 points per image, judging by the name — confirm).
            return_image: currently unused; the aligned images are always
                returned. Kept for API compatibility.
            batch_mode: if True, run one forward pass on the concatenated
                batch (all aligned images must share a size). If False, run
                one forward pass per sample.

        Returns:
            ``(batch_coeff, batch_align_img)``. ``batch_coeff`` is the model's
            ``output_coeff`` for all samples as a numpy array (first dim B).
            ``batch_align_img`` holds the aligned images as uint8
            (B, H', W', C).
        """
        bs = batched_images.shape[0]
        data_lst = []
        for i in range(bs):
            align_im, lm = self.preprocess_data(batched_images[i], batched_lm5[i], self.lm3d_std)
            data_lst.append({'imgs': align_im, 'lms': lm})

        if not batch_mode:
            coeff_lst = []
            align_lst = []
            for data in data_lst:
                # One sample per forward pass; each entry's own aligned image
                # is converted alongside its coefficients.
                self.model.set_input(data)  # unpack data from data loader
                self.model.forward()
                coeff_lst.append(self.model.output_coeff.cpu().numpy())
                align_lst.append(self._to_uint8_hwc(data['imgs']))
            batch_coeff = np.concatenate(coeff_lst)  # [B, n_coeff]
            batch_align_img = np.concatenate(align_lst)  # [B, H', W', C]
        else:
            imgs = torch.cat([d['imgs'] for d in data_lst])
            lms = torch.cat([d['lms'] for d in data_lst])
            self.model.set_input({'imgs': imgs, 'lms': lms})  # unpack data from data loader
            self.model.forward()
            batch_coeff = self.model.output_coeff.cpu().numpy()  # [B, n_coeff]
            batch_align_img = self._to_uint8_hwc(imgs)  # [B, H', W', C]
        return batch_coeff, batch_align_img

    def forward(self, batched_images, batched_lm5, return_image=True):
        """Run ``recon_coeff`` in its default (batched) mode."""
        return self.recon_coeff(batched_images, batched_lm5, return_image)