Update roop/core.py

#1
by Apex-X - opened
Files changed (1) hide show
  1. roop/core.py +54 -59
roop/core.py CHANGED
@@ -1,5 +1,3 @@
1
- #!/usr/bin/env python3
2
-
3
  import os
4
  import sys
5
  # single thread doubles cuda performance - needs to be set before torch import
@@ -13,15 +11,20 @@ import platform
13
  import signal
14
  import shutil
15
  import argparse
 
16
  import onnxruntime
17
  import tensorflow
 
18
  import roop.globals
19
  import roop.metadata
20
  import roop.ui as ui
21
- from roop.predictor import predict_image, predict_video
22
  from roop.processors.frame.core import get_frame_processors_modules
23
  from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
24
 
 
 
 
25
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
26
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
27
 
@@ -33,18 +36,13 @@ def parse_args() -> None:
33
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
34
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
35
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
36
- program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true')
37
- program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true')
38
- program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true')
39
- program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true')
40
- program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0)
41
- program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0)
42
- program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85)
43
- program.add_argument('--temp-frame-format', help='image format used for frame extraction', dest='temp_frame_format', default='png', choices=['jpg', 'png'])
44
- program.add_argument('--temp-frame-quality', help='image quality used for frame extraction', dest='temp_frame_quality', type=int, default=0, choices=range(101), metavar='[0-100]')
45
- program.add_argument('--output-video-encoder', help='encoder used for the output video', dest='output_video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc'])
46
- program.add_argument('--output-video-quality', help='quality used for the output video', dest='output_video_quality', type=int, default=35, choices=range(101), metavar='[0-100]')
47
- program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int)
48
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
49
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
50
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
@@ -54,19 +52,14 @@ def parse_args() -> None:
54
  roop.globals.source_path = args.source_path
55
  roop.globals.target_path = args.target_path
56
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
57
- roop.globals.headless = roop.globals.source_path is not None and roop.globals.target_path is not None and roop.globals.output_path is not None
58
  roop.globals.frame_processors = args.frame_processor
 
59
  roop.globals.keep_fps = args.keep_fps
 
60
  roop.globals.keep_frames = args.keep_frames
61
- roop.globals.skip_audio = args.skip_audio
62
  roop.globals.many_faces = args.many_faces
63
- roop.globals.reference_face_position = args.reference_face_position
64
- roop.globals.reference_frame_number = args.reference_frame_number
65
- roop.globals.similar_face_distance = args.similar_face_distance
66
- roop.globals.temp_frame_format = args.temp_frame_format
67
- roop.globals.temp_frame_quality = args.temp_frame_quality
68
- roop.globals.output_video_encoder = args.output_video_encoder
69
- roop.globals.output_video_quality = args.output_video_quality
70
  roop.globals.max_memory = args.max_memory
71
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
72
  roop.globals.execution_threads = args.execution_threads
@@ -81,14 +74,22 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]:
81
  if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
82
 
83
 
 
 
 
 
 
 
84
  def suggest_execution_providers() -> List[str]:
85
  return encode_execution_providers(onnxruntime.get_available_providers())
86
 
87
 
88
  def suggest_execution_threads() -> int:
89
- if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
90
- return 8
91
- return 1
 
 
92
 
93
 
94
  def limit_resources() -> None:
@@ -105,13 +106,18 @@ def limit_resources() -> None:
105
  memory = roop.globals.max_memory * 1024 ** 6
106
  if platform.system().lower() == 'windows':
107
  import ctypes
108
- kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
109
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
110
  else:
111
  import resource
112
  resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
113
 
114
 
 
 
 
 
 
115
  def pre_check() -> bool:
116
  if sys.version_info < (3, 9):
117
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
@@ -137,12 +143,11 @@ def start() -> None:
137
  if predict_image(roop.globals.target_path):
138
  destroy()
139
  shutil.copy2(roop.globals.target_path, roop.globals.output_path)
140
- # process frame
141
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
142
  update_status('Progressing...', frame_processor.NAME)
143
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
144
  frame_processor.post_process()
145
- # validate image
146
  if is_image(roop.globals.target_path):
147
  update_status('Processing to image succeed!')
148
  else:
@@ -151,48 +156,36 @@ def start() -> None:
151
  # process image to videos
152
  if predict_video(roop.globals.target_path):
153
  destroy()
154
- update_status('Creating temporary resources...')
155
  create_temp(roop.globals.target_path)
156
- # extract frames
157
- if roop.globals.keep_fps:
158
- fps = detect_fps(roop.globals.target_path)
159
- update_status(f'Extracting frames with {fps} FPS...')
160
- extract_frames(roop.globals.target_path, fps)
161
- else:
162
- update_status('Extracting frames with 30 FPS...')
163
- extract_frames(roop.globals.target_path)
164
- # process frame
165
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
166
- if temp_frame_paths:
167
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
168
- update_status('Progressing...', frame_processor.NAME)
169
- frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
170
- frame_processor.post_process()
171
- else:
172
- update_status('Frames not found...')
173
- return
174
- # create video
175
  if roop.globals.keep_fps:
 
176
  fps = detect_fps(roop.globals.target_path)
177
- update_status(f'Creating video with {fps} FPS...')
178
  create_video(roop.globals.target_path, fps)
179
  else:
180
- update_status('Creating video with 30 FPS...')
181
  create_video(roop.globals.target_path)
182
  # handle audio
183
- if roop.globals.skip_audio:
184
- move_temp(roop.globals.target_path, roop.globals.output_path)
185
- update_status('Skipping audio...')
186
- else:
187
  if roop.globals.keep_fps:
188
  update_status('Restoring audio...')
189
  else:
190
  update_status('Restoring audio might cause issues as fps are not kept...')
191
  restore_audio(roop.globals.target_path, roop.globals.output_path)
192
- # clean temp
193
- update_status('Cleaning temporary resources...')
 
194
  clean_temp(roop.globals.target_path)
195
- # validate video
196
  if is_video(roop.globals.target_path):
197
  update_status('Processing to video succeed!')
198
  else:
@@ -202,7 +195,7 @@ def start() -> None:
202
  def destroy() -> None:
203
  if roop.globals.target_path:
204
  clean_temp(roop.globals.target_path)
205
- sys.exit()
206
 
207
 
208
  def run() -> None:
@@ -218,3 +211,5 @@ def run() -> None:
218
  else:
219
  window = ui.init(start, destroy)
220
  window.mainloop()
 
 
 
 
 
1
  import os
2
  import sys
3
  # single thread doubles cuda performance - needs to be set before torch import
 
11
  import signal
12
  import shutil
13
  import argparse
14
+ import torch
15
  import onnxruntime
16
  import tensorflow
17
+
18
  import roop.globals
19
  import roop.metadata
20
  import roop.ui as ui
21
+ from roop.predicter import predict_image, predict_video
22
  from roop.processors.frame.core import get_frame_processors_modules
23
  from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
24
 
25
+ if 'ROCMExecutionProvider' in roop.globals.execution_providers:
26
+ del torch
27
+
28
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
29
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
30
 
 
36
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
37
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
38
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
39
+ program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
40
+ program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
41
+ program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
42
+ program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
43
+ program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
44
+ program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
45
+ program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
 
 
 
 
 
46
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
47
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
48
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
 
52
  roop.globals.source_path = args.source_path
53
  roop.globals.target_path = args.target_path
54
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
 
55
  roop.globals.frame_processors = args.frame_processor
56
+ roop.globals.headless = args.source_path or args.target_path or args.output_path
57
  roop.globals.keep_fps = args.keep_fps
58
+ roop.globals.keep_audio = args.keep_audio
59
  roop.globals.keep_frames = args.keep_frames
 
60
  roop.globals.many_faces = args.many_faces
61
+ roop.globals.video_encoder = args.video_encoder
62
+ roop.globals.video_quality = args.video_quality
 
 
 
 
 
63
  roop.globals.max_memory = args.max_memory
64
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
65
  roop.globals.execution_threads = args.execution_threads
 
74
  if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
75
 
76
 
77
+ def suggest_max_memory() -> int:
78
+ if platform.system().lower() == 'darwin':
79
+ return 4
80
+ return 16
81
+
82
+
83
  def suggest_execution_providers() -> List[str]:
84
  return encode_execution_providers(onnxruntime.get_available_providers())
85
 
86
 
87
  def suggest_execution_threads() -> int:
88
+ if 'DmlExecutionProvider' in roop.globals.execution_providers:
89
+ return 1
90
+ if 'ROCMExecutionProvider' in roop.globals.execution_providers:
91
+ return 1
92
+ return 8
93
 
94
 
95
  def limit_resources() -> None:
 
106
  memory = roop.globals.max_memory * 1024 ** 6
107
  if platform.system().lower() == 'windows':
108
  import ctypes
109
+ kernel32 = ctypes.windll.kernel32
110
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
111
  else:
112
  import resource
113
  resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
114
 
115
 
116
+ def release_resources() -> None:
117
+ if 'CUDAExecutionProvider' in roop.globals.execution_providers:
118
+ torch.cuda.empty_cache()
119
+
120
+
121
  def pre_check() -> bool:
122
  if sys.version_info < (3, 9):
123
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
 
143
  if predict_image(roop.globals.target_path):
144
  destroy()
145
  shutil.copy2(roop.globals.target_path, roop.globals.output_path)
 
146
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
147
  update_status('Progressing...', frame_processor.NAME)
148
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
149
  frame_processor.post_process()
150
+ release_resources()
151
  if is_image(roop.globals.target_path):
152
  update_status('Processing to image succeed!')
153
  else:
 
156
  # process image to videos
157
  if predict_video(roop.globals.target_path):
158
  destroy()
159
+ update_status('Creating temp resources...')
160
  create_temp(roop.globals.target_path)
161
+ update_status('Extracting frames...')
162
+ extract_frames(roop.globals.target_path)
 
 
 
 
 
 
 
163
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
164
+ for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
165
+ update_status('Progressing...', frame_processor.NAME)
166
+ frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
167
+ frame_processor.post_process()
168
+ release_resources()
169
+ # handles fps
 
 
 
170
  if roop.globals.keep_fps:
171
+ update_status('Detecting fps...')
172
  fps = detect_fps(roop.globals.target_path)
173
+ update_status(f'Creating video with {fps} fps...')
174
  create_video(roop.globals.target_path, fps)
175
  else:
176
+ update_status('Creating video with 30.0 fps...')
177
  create_video(roop.globals.target_path)
178
  # handle audio
179
+ if roop.globals.keep_audio:
 
 
 
180
  if roop.globals.keep_fps:
181
  update_status('Restoring audio...')
182
  else:
183
  update_status('Restoring audio might cause issues as fps are not kept...')
184
  restore_audio(roop.globals.target_path, roop.globals.output_path)
185
+ else:
186
+ move_temp(roop.globals.target_path, roop.globals.output_path)
187
+ # clean and validate
188
  clean_temp(roop.globals.target_path)
 
189
  if is_video(roop.globals.target_path):
190
  update_status('Processing to video succeed!')
191
  else:
 
195
  def destroy() -> None:
196
  if roop.globals.target_path:
197
  clean_temp(roop.globals.target_path)
198
+ quit()
199
 
200
 
201
  def run() -> None:
 
211
  else:
212
  window = ui.init(start, destroy)
213
  window.mainloop()
214
+
215
+