Spaces:
Configuration error
Configuration error
File size: 7,202 Bytes
b4af39b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os.path as osp
import glob
import logging
import insightface
from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession
from insightface.model_zoo.retinaface import RetinaFace
from insightface.model_zoo.landmark import Landmark
from insightface.model_zoo.attribute import Attribute
from insightface.model_zoo.inswapper import INSwapper
from insightface.model_zoo.arcface_onnx import ArcFaceONNX
from insightface.app import FaceAnalysis
from insightface.utils import DEFAULT_MP_NAME, ensure_available
from insightface.model_zoo import model_zoo
import onnxruntime
import onnx
from onnx import numpy_helper
from scripts.reactor_logger import logger
def patched_get_model_log(self, **kwargs):
    """Open ``self.onnx_file``, print its execution providers, and wrap it.

    Verbose variant of the router: same classification rules as
    ``patched_get_model``, plus one line reporting the providers and
    provider options of the created ONNX Runtime session.

    Args:
        self: the ``ModelRouter`` instance; only ``self.onnx_file`` is read.
        **kwargs: forwarded unchanged to ``PickableInferenceSession``.

    Returns:
        A ``RetinaFace``, ``Landmark``, ``Attribute``, ``INSwapper`` or
        ``ArcFaceONNX`` wrapper that reuses the created session, or ``None``
        when the model's input/output signature matches none of the rules.
    """
    session = PickableInferenceSession(self.onnx_file, **kwargs)
    # Public accessors rather than the private ``_providers`` /
    # ``_provider_options`` attributes of the session object.
    print(f'Applied providers: {session.get_providers()}, with options: {session.get_provider_options()}')
    inputs = session.get_inputs()
    input_shape = inputs[0].shape
    outputs = session.get_outputs()
    if len(outputs) >= 5:
        return RetinaFace(model_file=self.onnx_file, session=session)
    # Every remaining rule reads dims 2 and 3 of the first input; a lower-rank
    # input cannot match any of them (previously this raised IndexError).
    if len(input_shape) < 4:
        return None
    height, width = input_shape[2], input_shape[3]
    if height == 192 and width == 192:
        return Landmark(model_file=self.onnx_file, session=session)
    if height == 96 and width == 96:
        return Attribute(model_file=self.onnx_file, session=session)
    if len(inputs) == 2 and height == width and height in (128, 256):
        return INSwapper(model_file=self.onnx_file, session=session)
    # ONNX Runtime reports dynamic dims as str/None; ``>=`` and ``%`` on those
    # raise TypeError, so only concrete integer sizes can be ArcFace models.
    if isinstance(height, int) and height == width and height >= 112 and height % 16 == 0:
        return ArcFaceONNX(model_file=self.onnx_file, session=session)
    return None
def patched_get_model(self, **kwargs):
    """Open ``self.onnx_file`` and wrap it in the matching insightface class.

    Quiet router: classifies the model purely from the signature of the
    ONNX Runtime session, without printing anything.

    Args:
        self: the ``ModelRouter`` instance; only ``self.onnx_file`` is read.
        **kwargs: forwarded unchanged to ``PickableInferenceSession``.

    Returns:
        A ``RetinaFace``, ``Landmark``, ``Attribute``, ``INSwapper`` or
        ``ArcFaceONNX`` wrapper that reuses the created session, or ``None``
        when the model's input/output signature matches none of the rules.
    """
    session = PickableInferenceSession(self.onnx_file, **kwargs)
    inputs = session.get_inputs()
    input_shape = inputs[0].shape
    outputs = session.get_outputs()
    if len(outputs) >= 5:
        return RetinaFace(model_file=self.onnx_file, session=session)
    # Every remaining rule reads dims 2 and 3 of the first input; a lower-rank
    # input cannot match any of them (previously this raised IndexError).
    if len(input_shape) < 4:
        return None
    height, width = input_shape[2], input_shape[3]
    if height == 192 and width == 192:
        return Landmark(model_file=self.onnx_file, session=session)
    if height == 96 and width == 96:
        return Attribute(model_file=self.onnx_file, session=session)
    if len(inputs) == 2 and height == width and height in (128, 256):
        return INSwapper(model_file=self.onnx_file, session=session)
    # ONNX Runtime reports dynamic dims as str/None; ``>=`` and ``%`` on those
    # raise TypeError, so only concrete integer sizes can be ArcFace models.
    if isinstance(height, int) and height == width and height >= 112 and height % 16 == 0:
        return ArcFaceONNX(model_file=self.onnx_file, session=session)
    return None
def patched_faceanalysis_init(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
    """Load every ONNX model of a model pack and index it by task name.

    Args:
        self: the ``FaceAnalysis`` instance being initialised.
        name: model-pack name passed to ``ensure_available``.
        root: root directory passed to ``ensure_available``.
        allowed_modules: optional collection of task names to keep; models
            whose ``taskname`` is not in it are dropped. ``None`` keeps all.
        **kwargs: forwarded to ``model_zoo.get_model`` for every file.

    Sets ``self.models`` (task name -> first model found for that task, in
    sorted file order), ``self.model_dir`` and ``self.det_model``.

    Raises:
        AssertionError: if no model with task name ``'detection'`` was kept.
    """
    # Severity 3 is ERROR on ONNX Runtime's scale: hide its info/warning output.
    onnxruntime.set_default_logger_severity(3)
    self.models = {}
    self.model_dir = ensure_available('models', name, root=root)
    # Sorted so that, among duplicates of one task, the winner is deterministic.
    for onnx_file in sorted(glob.glob(osp.join(self.model_dir, '*.onnx'))):
        model = model_zoo.get_model(onnx_file, **kwargs)
        if model is None:
            print('model not recognized:', onnx_file)
        elif allowed_modules is not None and model.taskname not in allowed_modules:
            print('model ignore:', onnx_file, model.taskname)
            del model
        # allowed_modules was already enforced by the previous branch.
        elif model.taskname not in self.models:
            self.models[model.taskname] = model
        else:
            print('duplicated model task type, ignore:', onnx_file, model.taskname)
            del model
    assert 'detection' in self.models, f'no detection model found in {self.model_dir}'
    self.det_model = self.models['detection']
def patched_faceanalysis_prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
    """Call ``prepare(ctx_id, ...)`` on every model in ``self.models``.

    ``det_thresh`` and ``det_size`` are stored on ``self``. Only the model
    registered under ``'detection'`` additionally receives them, as the
    ``det_thresh`` and ``input_size`` keyword arguments.

    Raises:
        AssertionError: if ``det_size`` is ``None`` (``det_thresh`` has
            already been stored at that point).
    """
    self.det_thresh = det_thresh
    assert det_size is not None
    self.det_size = det_size
    for task, net in self.models.items():
        extra = {'input_size': det_size, 'det_thresh': det_thresh} if task == 'detection' else {}
        net.prepare(ctx_id, **extra)
def patched_inswapper_init(self, model_file=None, session=None):
    """Initialise an ``INSwapper`` from an ONNX file and an optional session.

    Args:
        self: the ``INSwapper`` instance being initialised.
        model_file: path of the ONNX swapper model; always parsed with
            ``onnx.load`` to extract ``self.emap``.
        session: an existing ONNX Runtime session to reuse. When ``None`` a
            new ``onnxruntime.InferenceSession`` is created for ``model_file``.

    Sets ``model_file``, ``session``, ``emap``, ``input_mean``, ``input_std``,
    ``input_names``, ``output_names``, ``input_shape`` and ``input_size``.

    Raises:
        AssertionError: if the model does not have exactly one output.
    """
    self.model_file = model_file
    self.session = session
    # The last graph initializer is stored as ``emap``. This relies on the
    # swapper file's layout (TODO confirm against the model export).
    # The parsed proto holds all weights, so it is used as a temporary and
    # freed right away instead of staying alive while another session loads.
    self.emap = numpy_helper.to_array(onnx.load(self.model_file).graph.initializer[-1])
    self.input_mean = 0.0
    self.input_std = 255.0
    if self.session is None:
        # Since ONNX Runtime 1.9, builds with more than one execution provider
        # raise ValueError unless ``providers`` is passed explicitly.
        self.session = onnxruntime.InferenceSession(
            self.model_file, None, providers=onnxruntime.get_available_providers()
        )
    inputs = self.session.get_inputs()
    self.input_names = [inp.name for inp in inputs]
    self.output_names = [out.name for out in self.session.get_outputs()]
    assert len(self.output_names) == 1
    input_shape = inputs[0].shape
    self.input_shape = input_shape
    # Dims 2 and 3 of the first input, reversed: (W, H) if the layout is NCHW.
    self.input_size = tuple(input_shape[2:4][::-1])
def pathced_retinaface_prepare(self, ctx_id, **kwargs):
    """Configure a RetinaFace detector before inference.

    A negative ``ctx_id`` switches the session to ``CPUExecutionProvider``.
    ``nms_thresh`` and ``det_thresh`` keyword arguments overwrite the
    matching attributes when given (not ``None``). ``input_size`` is applied
    only if the detector has no input size yet; otherwise it is ignored
    without any message.
    """
    if ctx_id < 0:
        self.session.set_providers(['CPUExecutionProvider'])
    for key in ('nms_thresh', 'det_thresh'):
        value = kwargs.get(key)
        if value is not None:
            setattr(self, key, value)
    requested_size = kwargs.get('input_size')
    if requested_size is not None and self.input_size is None:
        self.input_size = requested_size
def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init, retinaface_prepare):
    """Overwrite five insightface class attributes with the given callables.

    Targets, in argument order: ``ModelRouter.get_model``,
    ``FaceAnalysis.__init__``, ``FaceAnalysis.prepare``,
    ``INSwapper.__init__`` and ``RetinaFace.prepare``. The patch is applied
    to the classes themselves, so it affects every existing and future
    instance.
    """
    zoo = insightface.model_zoo
    replacements = (
        (zoo.model_zoo.ModelRouter, 'get_model', get_model),
        (insightface.app.FaceAnalysis, '__init__', faceanalysis_init),
        (insightface.app.FaceAnalysis, 'prepare', faceanalysis_prepare),
        (zoo.inswapper.INSwapper, '__init__', inswapper_init),
        (zoo.retinaface.RetinaFace, 'prepare', retinaface_prepare),
    )
    for owner, attribute, replacement in replacements:
        setattr(owner, attribute, replacement)
# Argument lists for patch_insightface(); entry order must match its
# parameters (get_model, FaceAnalysis.__init__, FaceAnalysis.prepare,
# INSwapper.__init__, RetinaFace.prepare).
#
# Installed by apply_patch(2). The get_model slot uses this module's
# patched_get_model_log instead of insightface's ModelRouter.get_model (an
# earlier revision used the latter). The other four entries are insightface's
# own implementations, captured when this module is imported.
original_functions = [
    patched_get_model_log,
    FaceAnalysis.__init__,
    FaceAnalysis.prepare,
    INSwapper.__init__,
    RetinaFace.prepare,
]
# Installed by apply_patch(0) and apply_patch(1): the replacements defined
# in this module.
patched_functions = [
    patched_get_model,
    patched_faceanalysis_init,
    patched_faceanalysis_prepare,
    patched_inswapper_init,
    pathced_retinaface_prepare,
]
def apply_patch(console_log_level):
    """Select which insightface overrides are active and set the log level.

    ========  ======================  ==========================
    level     functions installed     ``logger`` level
    ========  ======================  ==========================
    0         ``patched_functions``   ``logging.WARNING``
    1         ``patched_functions``   ``logging.STATUS``
    2         ``original_functions``  ``logging.INFO``
    ========  ======================  ==========================

    Any other value leaves everything unchanged. ``logging.STATUS`` is not a
    stdlib level; presumably it is registered by ``scripts.reactor_logger``
    (verify there).
    """
    modes = (
        (0, True, 'WARNING'),
        (1, True, 'STATUS'),
        (2, False, 'INFO'),
    )
    for level, use_patched, level_name in modes:
        if console_log_level == level:
            patch_insightface(*(patched_functions if use_patched else original_functions))
            logger.setLevel(getattr(logging, level_name))
            return
|