|
|
|
|
|
|
|
|
|
import numpy as np |
|
from model_lib import ModelBase |
|
|
|
# Registry of the available model variants, keyed by model name.
# HifiFace.__init__ looks a name up here and hands the resulting config
# dict to ModelBase.
MODEL_ZOO = {
    'er8_bs1': {
        'model_path': 'pretrain_models/9O_865k.onnx',
    },
}
|
|
|
|
|
class HifiFace(ModelBase):
    """HifiFace model wrapper built on top of ``ModelBase``.

    ``self.model`` and ``self.model_type`` are read by :meth:`forward` but are
    not assigned in this class; presumably ``ModelBase.__init__`` sets them
    (TODO confirm against ``model_lib``).
    """

    def __init__(self, model_name='er8_bs1', provider='gpu'):
        """Select a model config from ``MODEL_ZOO`` and initialise the base.

        Args:
            model_name: Key into ``MODEL_ZOO``. An unknown name raises
                ``KeyError`` from the dictionary lookup.
            provider: Forwarded unchanged to ``ModelBase.__init__``.
        """
        super().__init__(MODEL_ZOO[model_name], provider)

    def forward(self, src_face_image, dst_face_latent):
        """Run the model on one face image and one latent array.

        Args:
            src_face_image: Array with exactly three axes. Its last axis is
                moved to the front, the values are rescaled as
                ``x / 255 * 2 - 1`` and a leading axis of size 1 is added.
                Presumably an HWC image with values in [0, 255], which would
                map to [-1, 1] -- TODO confirm with callers.
            dst_face_latent: Array passed to the model as the second input
                after conversion to ``float32``. Its expected shape is not
                visible from this file.

        Returns:
            A ``(mask, swap_face)`` tuple, always in that order regardless of
            which backend produced the outputs.
        """
        # Last axis to the front, rescale, then prepend a size-1 leading axis.
        img_tensor = ((src_face_image.transpose(2, 0, 1) / 255.0) * 2 - 1)[None]
        blob = [img_tensor.astype(np.float32), dst_face_latent.astype(np.float32)]
        output = self.model.forward(blob)

        # The 'trt' backend yields (mask, swap_face); every other backend
        # yields (swap_face, mask). Normalise to a single return order.
        if self.model_type == 'trt':
            mask, swap_face = output
        else:
            swap_face, mask = output

        return mask, swap_face
|
|