File size: 3,780 Bytes
a1da63c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import importlib
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue
from types import ModuleType
from typing import Any, List

from tqdm import tqdm

from facefusion import logger, state_manager, wording
from facefusion.exit_helper import hard_exit
from facefusion.typing import ProcessFrames, QueuePayload

# Processor modules cached by get_processors_modules(); rebound to a fresh empty list by clear_processors_modules().
PROCESSORS_MODULES : List[ModuleType] = []
# Attribute names every processor module must expose; load_processor_module() checks each one with hasattr()
# and exits when any is missing.
PROCESSORS_METHODS =\
[
	'get_inference_pool', 'clear_inference_pool',
	'register_args', 'apply_args',
	'pre_check', 'pre_process', 'post_process',
	'get_reference_frame',
	'process_frame', 'process_frames', 'process_image', 'process_video'
]


def load_processor_module(processor : str) -> Any:
	"""
	Import ``facefusion.processors.modules.<processor>`` and check that it exposes every name in PROCESSORS_METHODS.

	:param processor: name of the processor module inside ``facefusion.processors.modules``
	:return: the imported module
	On ModuleNotFoundError or a missing method an error is logged and hard_exit(1) is called.
	"""
	module_path = 'facefusion.processors.modules.' + processor

	try:
		processor_module = importlib.import_module(module_path)
		# all() stops at the first missing attribute, mirroring a check-and-raise loop
		if not all(hasattr(processor_module, method_name) for method_name in PROCESSORS_METHODS):
			raise NotImplementedError
	except ModuleNotFoundError as exception:
		# also reached when a dependency imported by the processor module is missing, not only the module itself
		logger.error(wording.get('processor_not_loaded').format(processor = processor), __name__.upper())
		logger.debug(exception.msg, __name__.upper())
		hard_exit(1)
	except NotImplementedError:
		logger.error(wording.get('processor_not_implemented').format(processor = processor), __name__.upper())
		hard_exit(1)
	return processor_module


def get_processors_modules(processors : List[str]) -> List[ModuleType]:
	"""
	Return the cached processor modules, loading ``processors`` on the first call that finds the cache empty.

	Once the cache is populated, later calls return it unchanged and ignore ``processors``
	until clear_processors_modules() resets it.
	"""
	global PROCESSORS_MODULES

	if PROCESSORS_MODULES:
		return PROCESSORS_MODULES

	for processor_name in processors:
		PROCESSORS_MODULES.append(load_processor_module(processor_name))
	return PROCESSORS_MODULES


def clear_processors_modules() -> None:
	"""
	Call clear_inference_pool() on every cached processor module, then drop the cache.

	The global is rebound to a new empty list rather than emptied in place, so lists
	previously returned by get_processors_modules() keep their contents.
	"""
	global PROCESSORS_MODULES

	cached_modules = PROCESSORS_MODULES
	for cached_module in cached_modules:
		cached_module.clear_inference_pool()
	PROCESSORS_MODULES = []


def multi_process_frames(source_paths : List[str], temp_frame_paths : List[str], process_frames : ProcessFrames) -> None:
	"""
	Run ``process_frames`` over all temp frames using a thread pool, with a tqdm progress bar.

	The frames become numbered payloads (ordered by file basename) and are split into batches.
	Each batch is submitted as ``process_frames(source_paths, batch, progress.update)``.
	The call blocks until every batch is done, and any worker exception is re-raised via ``Future.result()``.
	The progress bar is disabled when the log level is 'warn' or 'error'.
	"""
	queue_payloads = create_queue_payloads(temp_frame_paths)
	is_quiet = state_manager.get_item('log_level') in [ 'warn', 'error' ]

	with tqdm(total = len(queue_payloads), desc = wording.get('processing'), unit = 'frame', ascii = ' =', disable = is_quiet) as progress:
		thread_count = state_manager.get_item('execution_thread_count')
		queue_count = state_manager.get_item('execution_queue_count')
		progress.set_postfix(
		{
			'execution_providers': state_manager.get_item('execution_providers'),
			'execution_thread_count': thread_count,
			'execution_queue_count': queue_count
		})

		with ThreadPoolExecutor(max_workers = thread_count) as executor:
			payload_queue : Queue[QueuePayload] = create_queue(queue_payloads)
			# `//` and `*` share precedence and evaluate left to right: (frames // threads) * queue_count, at least 1
			batch_size = max(len(queue_payloads) // thread_count * queue_count, 1)
			submitted = []

			while not payload_queue.empty():
				batch = pick_queue(payload_queue, batch_size)
				submitted.append(executor.submit(process_frames, source_paths, batch, progress.update))

			# result() re-raises the first worker exception encountered in completion order
			for finished in as_completed(submitted):
				finished.result()


def create_queue(queue_payloads : List[QueuePayload]) -> Queue[QueuePayload]:
	"""Return a new unbounded FIFO queue pre-filled with ``queue_payloads`` in list order."""
	payload_queue : Queue[QueuePayload] = Queue()

	for payload in queue_payloads:
		payload_queue.put(payload)
	return payload_queue


def pick_queue(queue : Queue[QueuePayload], queue_per_future : int) -> List[QueuePayload]:
	"""
	Remove and return up to ``queue_per_future`` payloads from the front of ``queue``.

	Fewer are returned when the queue drains first; a non-positive count returns an empty list.
	"""
	picked_payloads = []

	for _ in range(queue_per_future):
		if queue.empty():
			break
		picked_payloads.append(queue.get())
	return picked_payloads


def create_queue_payloads(temp_frame_paths : List[str]) -> List[QueuePayload]:
	"""
	Order ``temp_frame_paths`` by file basename and number them from 0.

	The input list is not modified; the sort is stable, so equal basenames keep their input order.
	:return: one ``{'frame_number', 'frame_path'}`` dict per path
	"""
	ordered_paths = sorted(temp_frame_paths, key = os.path.basename)
	return\
	[
		{
			'frame_number': index,
			'frame_path': path
		}
		for index, path in enumerate(ordered_paths)
	]