# Source: ReHiFace-S/face_lib/face_swap/hififace_api.py (GuijiAI, commit 89cf463 "Upload 117 files", verified)
# -*- coding: utf-8 -*-
# @Time : 2022/8/25
# @Author : ykk648
# @Project : https://github.com/ykk648/AI_power
import numpy as np
from model_lib import ModelBase
# Registry of available HifiFace checkpoints. Keys are the ``model_name``
# values accepted by HifiFace.__init__; each value is the config dict that
# HifiFace hands to ModelBase.
MODEL_ZOO = dict(
    er8_bs1=dict(
        model_path='pretrain_models/9O_865k.onnx',
    ),
)
class HifiFace(ModelBase):
    """HifiFace face-swap network wrapper.

    The checkpoint config is looked up in ``MODEL_ZOO`` by name and handed to
    ``ModelBase``. Inference goes through ``self.model.forward``, and the
    backend is identified by ``self.model_type``. Neither attribute is set in
    this class, so both are expected to come from ``ModelBase``.
    """

    def __init__(self, model_name='er8_bs1', provider='gpu'):
        """Load the checkpoint registered under ``model_name``.

        Args:
            model_name: key into ``MODEL_ZOO``.
            provider: execution provider, passed through unchanged to
                ``ModelBase``.

        Raises:
            KeyError: if ``model_name`` is not a key of ``MODEL_ZOO``. The
                message lists the available names.
        """
        try:
            model_cfg = MODEL_ZOO[model_name]
        except KeyError:
            # Same exception type as before, but say what the valid choices are.
            raise KeyError(
                f"unknown HifiFace model {model_name!r}; "
                f"available: {sorted(MODEL_ZOO)}"
            ) from None
        super().__init__(model_cfg, provider)

    def forward(self, src_face_image, dst_face_latent):
        """Run one face-swap inference.

        Args:
            src_face_image: 3-D image array whose last axis is moved to the
                front (presumably (H, W, C) -> (C, H, W); TODO confirm).
                Pixel values are expected in [0, 255]; they are rescaled to
                [-1, 1], given a leading batch axis, and cast to float32.
            dst_face_latent: array or array-like fed to the network's second
                input as float32. No batch axis is added here, so it
                presumably must already be batched. TODO confirm against the
                model's input signature.

        Returns:
            tuple: ``(mask, swap_face)`` as produced by the network. No
            post-processing is applied here.
        """
        # Move the last axis to the front, map [0, 255] -> [-1, 1], and add a
        # leading batch axis of size 1.
        img_tensor = ((src_face_image.transpose(2, 0, 1) / 255.0) * 2 - 1)[None]
        blob = [
            img_tensor.astype(np.float32),
            # asarray also accepts plain sequences; for ndarrays it matches
            # the previous ``.astype(np.float32)`` result.
            np.asarray(dst_face_latent, dtype=np.float32),
        ]
        output = self.model.forward(blob)
        # The 'trt' backend yields the two outputs in the opposite order from
        # the other backends (presumably engine binding order; re-verify if
        # the engine is rebuilt). Normalise to (mask, swap_face).
        if self.model_type == 'trt':
            mask, swap_face = output
        else:
            swap_face, mask = output
        return mask, swap_face