# facetest / facefusion / inference_manager.py
# Uploaded by LULDev: "Upload folder using huggingface_hub" (commit a1da63c, verified)
from functools import lru_cache
from time import sleep
from typing import List
import onnx
from onnxruntime import InferenceSession
from facefusion import process_manager, state_manager
from facefusion.app_context import detect_app_context
from facefusion.execution import create_execution_providers, has_execution_provider
from facefusion.thread_helper import thread_lock
from facefusion.typing import DownloadSet, ExecutionProviderKey, InferencePool, InferencePoolSet, ModelInitializer
# Inference pools cached per application context ('cli' / 'ui'); each inner dict maps a
# model context name to its InferencePool, or to None after clear_inference_pool() ran.
INFERENCE_POOLS : InferencePoolSet = {
	'cli': {}, # type:ignore[typeddict-item]
	'ui': {} # type:ignore[typeddict-item]
}
def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool:
	"""
	Return the inference pool for model_context in the current app context, creating it on first use.

	Everything runs under the shared thread_lock(). The call first polls every 0.5 s until
	process_manager.is_checking() is false. If no pool is cached for model_context (missing or
	reset to None), one is built from model_sources using the configured 'execution_device_id'
	and the providers returned by find_execution_providers().

	:param model_context: key identifying the caller's model set, e.g. 'facefusion.frame_colorizer'
	:param model_sources: mapping of model name to a source entry that provides a 'path'
	:return: the cached or freshly created inference pool
	"""
	with thread_lock():
		# do not touch the pools while a check is in progress
		while process_manager.is_checking():
			sleep(0.5)

		context_pools = INFERENCE_POOLS.get(detect_app_context())

		if context_pools.get(model_context) is None:
			# argument evaluation order matches the original single-call form
			execution_device_id = state_manager.get_item('execution_device_id')
			execution_provider_keys = find_execution_providers(model_context)
			context_pools[model_context] = create_inference_pool(model_sources, execution_device_id, execution_provider_keys)
		return context_pools.get(model_context)
def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> InferencePool:
	"""
	Build one inference session per entry of model_sources, keyed by model name.

	:param model_sources: mapping of model name to a source entry whose 'path' points at the model file
	:param execution_device_id: device id forwarded to create_inference_session()
	:param execution_provider_keys: execution providers forwarded to create_inference_session()
	:return: dict of model name to its InferenceSession
	"""
	inference_pool : InferencePool = {
		model_name: create_inference_session(model_source.get('path'), execution_device_id, execution_provider_keys)
		for model_name, model_source in model_sources.items()
	}
	return inference_pool
def clear_inference_pool(model_context : str) -> None:
	"""
	Forget the cached pool for model_context in the current app context.

	The entry is reset to None instead of being deleted; get_inference_pool() treats None as
	"not built yet" and recreates the pool on its next call. Unlike get_inference_pool(),
	this function does not acquire thread_lock().

	:param model_context: key the pool was stored under
	"""
	context_pools = INFERENCE_POOLS[detect_app_context()]
	context_pools[model_context] = None
def create_inference_session(model_path : str, execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> InferenceSession:
	"""
	Open an onnxruntime InferenceSession for the model at model_path.

	:param model_path: filesystem path of the ONNX model
	:param execution_device_id: device id passed to create_execution_providers()
	:param execution_provider_keys: provider keys passed to create_execution_providers()
	:return: the new InferenceSession configured with those providers
	"""
	return InferenceSession(model_path, providers = create_execution_providers(execution_device_id, execution_provider_keys))
@lru_cache(maxsize = None)
def get_static_model_initializer(model_path : str) -> ModelInitializer:
	"""
	Load the ONNX model at model_path and return its last graph initializer as a numpy array.

	Results are memoised per model_path by lru_cache, so every caller receives the same array
	object; treat it as read-only.

	:param model_path: filesystem path of the ONNX model
	:return: the last entry of model.graph.initializer converted via numpy_helper.to_array
	:raises IndexError: if the model graph has no initializers
	"""
	# Import the submodule explicitly. Reaching it as `onnx.numpy_helper` only works if the
	# onnx package (or some earlier import) happened to load it.
	from onnx import numpy_helper

	model = onnx.load(model_path)
	return numpy_helper.to_array(model.graph.initializer[-1])
def find_execution_providers(model_context : str) -> List[ExecutionProviderKey]:
	"""
	Choose the execution providers used to build the pool for model_context.

	When has_execution_provider('coreml') is true, 'facefusion.frame_colorizer' is forced onto
	[ 'cpu' ]. Every other case returns the configured 'execution_providers' state item.

	:param model_context: key identifying the caller's model set
	:return: list of execution provider keys
	"""
	# same short-circuit order as before: the CoreML probe always runs first
	force_cpu = has_execution_provider('coreml') and model_context == 'facefusion.frame_colorizer'
	return [ 'cpu' ] if force_cpu else state_manager.get_item('execution_providers')