|
"""3D face reconstruction module for Deep3DFaceRecon_pytorch, based on its test script.

The PyTorch Deep3D_Recon is about 8x faster than the TF-based version (16 s/iter ==> 2 s/iter).
|
""" |
|
|
|
import os |
|
|
|
import torch |
|
import torch.nn as nn |
|
from .deep_3drecon_models.facerecon_model import FaceReconModel |
|
from .util.preprocess import align_img |
|
from PIL import Image |
|
import numpy as np |
|
from .util.load_mats import load_lm3d |
|
import torch |
|
import pickle as pkl |
|
from PIL import Image |
|
|
|
from utils.commons.tensor_utils import convert_to_tensor, convert_to_np |
|
|
|
# Options object for FaceReconModel, unpickled once at import time.
# NOTE(review): pickle.load can execute arbitrary code; this file must come from a
# trusted source (e.g. shipped with the repository), never from user input.
_OPT_REL_PATH = "deep_3drecon/reconstructor_opt.pkl"
_opt_path = _OPT_REL_PATH
if not os.path.isfile(_opt_path):
    # The CWD-relative path above only resolves when running from the repository
    # root; fall back to the file next to this module so imports from other working
    # directories also work. Assumes this module lives in deep_3drecon/ — TODO confirm.
    _opt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reconstructor_opt.pkl")
with open(_opt_path, "rb") as f:
    opt = pkl.load(f)
|
|
|
class Reconstructor(nn.Module):
    """Regress FaceReconModel coefficients from images and their landmarks.

    The wrapped ``FaceReconModel`` is built from the module-level ``opt`` and
    kept in eval mode. Each image is aligned to the reference landmarks
    (``self.lm3d_std``) before being fed to the network.
    """

    def __init__(self, device='cuda:0'):
        """Construct the FaceReconModel and prepare it for inference.

        Args:
            device: device string assigned to ``self.model.device`` before
                ``parallelize()`` is called (which presumably places the networks
                on it). Defaults to ``'cuda:0'``, the previously hard-coded value.
                NOTE(review): non-CUDA devices are untested here — confirm the
                model supports them before relying on one.
        """
        super().__init__()
        self.model = FaceReconModel(opt)
        self.model.setup(opt)
        self.model.device = device
        self.model.parallelize()
        self.model.eval()
        # Reference landmarks handed to align_img for every image.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

    @staticmethod
    def _flip_lm_y(lm, height):
        """Return a copy of ``lm`` reshaped to (N, 2) with each y replaced by ``height - 1 - y``.

        Works on a copy so the caller's landmark array (often a view into a
        larger batch) is never modified.
        """
        flipped = np.array(lm).reshape(-1, 2)  # np.array copies by default
        flipped[:, -1] = height - 1 - flipped[:, -1]
        return flipped

    @staticmethod
    def _to_uint8_hwc(imgs):
        """Convert a (B, C, H, W) float tensor scaled by 1/255 to a (B, H, W, C) uint8 array.

        Values are truncated (``.int()``), not rounded.
        """
        return (imgs.permute(0, 2, 3, 1) * 255).int().numpy().astype(np.uint8)

    def _predict(self, data):
        """Run the model on one ``{'imgs', 'lms'}`` input dict and return ``output_coeff`` as numpy."""
        self.model.set_input(data)
        self.model.forward()
        return self.model.output_coeff.cpu().numpy()

    def preprocess_data(self, im, lm, lm3d_std):
        """Align a single image and its landmarks to the reference landmarks.

        Args:
            im: one image with shape (H, W, C); converted with ``convert_to_np``
                and ``PIL.Image.fromarray`` (presumably uint8 RGB — TODO confirm).
            lm: landmarks, reshaped to (N, 2) (x, y) pairs. The y axis is flipped
                to ``H - 1 - y`` before alignment. The caller's data is not modified.
            lm3d_std: reference landmarks passed through to ``align_img``.

        Returns:
            (im, lm): the aligned image as a float32 tensor of shape
            (1, C, H', W') divided by 255, and the aligned landmarks as a
            tensor with a leading batch dimension of 1.
        """
        height, _, _ = im.shape
        lm = self._flip_lm_y(convert_to_np(lm), height)
        _, im, lm, _ = align_img(Image.fromarray(convert_to_np(im)), lm, convert_to_np(lm3d_std))
        im = torch.tensor(np.array(im) / 255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
        lm = torch.tensor(lm).unsqueeze(0)
        return im, lm

    @torch.no_grad()
    def recon_coeff(self, batched_images, batched_lm5, return_image=True, batch_mode=True):
        """Predict coefficients for a batch of images.

        Args:
            batched_images: sequence/array of images indexed along dim 0, each
                accepted by :meth:`preprocess_data`.
            batched_lm5: landmarks for each image, indexed like ``batched_images``.
            return_image: accepted for API compatibility; currently ignored —
                the aligned images are always returned.
            batch_mode: if True, run all aligned samples through the model in a
                single forward pass; otherwise run one forward pass per sample.

        Returns:
            (batch_coeff, batch_align_img): the model's ``output_coeff`` values as
            a numpy array stacked along axis 0 (one entry per image, assuming
            ``output_coeff`` keeps a leading batch dim — TODO confirm), and the
            aligned images as a uint8 array of shape (B, H, W, C).
        """
        bs = batched_images.shape[0]
        data_lst = []
        for i in range(bs):
            align_im, lm = self.preprocess_data(batched_images[i], batched_lm5[i], self.lm3d_std)
            data_lst.append({'imgs': align_im, 'lms': lm})

        if batch_mode:
            imgs = torch.cat([d['imgs'] for d in data_lst])
            lms = torch.cat([d['lms'] for d in data_lst])
            batch_coeff = self._predict({'imgs': imgs, 'lms': lms})
            batch_align_img = self._to_uint8_hwc(imgs)
        else:
            # Feed each sample's own dict (not the whole list) and convert that
            # sample's own aligned image.
            batch_coeff = np.concatenate([self._predict(d) for d in data_lst])
            batch_align_img = np.concatenate([self._to_uint8_hwc(d['imgs']) for d in data_lst])
        return batch_coeff, batch_align_img

    def forward(self, batched_images, batched_lm5, return_image=True):
        """nn.Module entry point; delegates to :meth:`recon_coeff` with ``batch_mode=True``."""
        return self.recon_coeff(batched_images, batched_lm5, return_image)
|
|
|
|
|
|
|
|