Aadhithya commited on
Commit
75460a1
1 Parent(s): 8249541

Update roop/core.py

Browse files
Files changed (1) hide show
  1. roop/core.py +52 -58
roop/core.py CHANGED
@@ -1,5 +1,4 @@
1
  #!/usr/bin/env python3
2
-
3
  import os
4
  import sys
5
  # single thread doubles cuda performance - needs to be set before torch import
@@ -13,15 +12,20 @@ import platform
13
  import signal
14
  import shutil
15
  import argparse
 
16
  import onnxruntime
17
  import tensorflow
 
18
  import roop.globals
19
  import roop.metadata
20
  import roop.ui as ui
21
- from roop.predictor import predict_image, predict_video
22
  from roop.processors.frame.core import get_frame_processors_modules
23
  from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
24
 
 
 
 
25
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
26
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
27
 
@@ -33,18 +37,13 @@ def parse_args() -> None:
33
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
34
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
35
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
36
- program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true')
37
- program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true')
38
- program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true')
39
- program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true')
40
- program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0)
41
- program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0)
42
- program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85)
43
- program.add_argument('--temp-frame-format', help='image format used for frame extraction', dest='temp_frame_format', default='png', choices=['jpg', 'png'])
44
- program.add_argument('--temp-frame-quality', help='image quality used for frame extraction', dest='temp_frame_quality', type=int, default=0, choices=range(101), metavar='[0-100]')
45
- program.add_argument('--output-video-encoder', help='encoder used for the output video', dest='output_video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc'])
46
- program.add_argument('--output-video-quality', help='quality used for the output video', dest='output_video_quality', type=int, default=35, choices=range(101), metavar='[0-100]')
47
- program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int)
48
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
49
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
50
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
@@ -54,19 +53,14 @@ def parse_args() -> None:
54
  roop.globals.source_path = args.source_path
55
  roop.globals.target_path = args.target_path
56
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
57
- roop.globals.headless = roop.globals.source_path is not None and roop.globals.target_path is not None and roop.globals.output_path is not None
58
  roop.globals.frame_processors = args.frame_processor
 
59
  roop.globals.keep_fps = args.keep_fps
 
60
  roop.globals.keep_frames = args.keep_frames
61
- roop.globals.skip_audio = args.skip_audio
62
  roop.globals.many_faces = args.many_faces
63
- roop.globals.reference_face_position = args.reference_face_position
64
- roop.globals.reference_frame_number = args.reference_frame_number
65
- roop.globals.similar_face_distance = args.similar_face_distance
66
- roop.globals.temp_frame_format = args.temp_frame_format
67
- roop.globals.temp_frame_quality = args.temp_frame_quality
68
- roop.globals.output_video_encoder = args.output_video_encoder
69
- roop.globals.output_video_quality = args.output_video_quality
70
  roop.globals.max_memory = args.max_memory
71
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
72
  roop.globals.execution_threads = args.execution_threads
@@ -81,14 +75,22 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]:
81
  if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
82
 
83
 
 
 
 
 
 
 
84
  def suggest_execution_providers() -> List[str]:
85
  return encode_execution_providers(onnxruntime.get_available_providers())
86
 
87
 
88
  def suggest_execution_threads() -> int:
89
- if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
90
- return 8
91
- return 1
 
 
92
 
93
 
94
  def limit_resources() -> None:
@@ -105,13 +107,18 @@ def limit_resources() -> None:
105
  memory = roop.globals.max_memory * 1024 ** 6
106
  if platform.system().lower() == 'windows':
107
  import ctypes
108
- kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
109
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
110
  else:
111
  import resource
112
  resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
113
 
114
 
 
 
 
 
 
115
  def pre_check() -> bool:
116
  if sys.version_info < (3, 9):
117
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
@@ -137,12 +144,11 @@ def start() -> None:
137
  if predict_image(roop.globals.target_path):
138
  destroy()
139
  shutil.copy2(roop.globals.target_path, roop.globals.output_path)
140
- # process frame
141
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
142
  update_status('Progressing...', frame_processor.NAME)
143
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
144
  frame_processor.post_process()
145
- # validate image
146
  if is_image(roop.globals.target_path):
147
  update_status('Processing to image succeed!')
148
  else:
@@ -151,48 +157,36 @@ def start() -> None:
151
  # process image to videos
152
  if predict_video(roop.globals.target_path):
153
  destroy()
154
- update_status('Creating temporary resources...')
155
  create_temp(roop.globals.target_path)
156
- # extract frames
157
- if roop.globals.keep_fps:
158
- fps = detect_fps(roop.globals.target_path)
159
- update_status(f'Extracting frames with {fps} FPS...')
160
- extract_frames(roop.globals.target_path, fps)
161
- else:
162
- update_status('Extracting frames with 30 FPS...')
163
- extract_frames(roop.globals.target_path)
164
- # process frame
165
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
166
- if temp_frame_paths:
167
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
168
- update_status('Progressing...', frame_processor.NAME)
169
- frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
170
- frame_processor.post_process()
171
- else:
172
- update_status('Frames not found...')
173
- return
174
- # create video
175
  if roop.globals.keep_fps:
 
176
  fps = detect_fps(roop.globals.target_path)
177
- update_status(f'Creating video with {fps} FPS...')
178
  create_video(roop.globals.target_path, fps)
179
  else:
180
- update_status('Creating video with 30 FPS...')
181
  create_video(roop.globals.target_path)
182
  # handle audio
183
- if roop.globals.skip_audio:
184
- move_temp(roop.globals.target_path, roop.globals.output_path)
185
- update_status('Skipping audio...')
186
- else:
187
  if roop.globals.keep_fps:
188
  update_status('Restoring audio...')
189
  else:
190
  update_status('Restoring audio might cause issues as fps are not kept...')
191
  restore_audio(roop.globals.target_path, roop.globals.output_path)
192
- # clean temp
193
- update_status('Cleaning temporary resources...')
 
194
  clean_temp(roop.globals.target_path)
195
- # validate video
196
  if is_video(roop.globals.target_path):
197
  update_status('Processing to video succeed!')
198
  else:
@@ -202,7 +196,7 @@ def start() -> None:
202
  def destroy() -> None:
203
  if roop.globals.target_path:
204
  clean_temp(roop.globals.target_path)
205
- sys.exit()
206
 
207
 
208
  def run() -> None:
 
1
  #!/usr/bin/env python3
 
2
  import os
3
  import sys
4
  # single thread doubles cuda performance - needs to be set before torch import
 
12
  import signal
13
  import shutil
14
  import argparse
15
+ import torch
16
  import onnxruntime
17
  import tensorflow
18
+
19
  import roop.globals
20
  import roop.metadata
21
  import roop.ui as ui
22
+ from roop.predicter import predict_image, predict_video
23
  from roop.processors.frame.core import get_frame_processors_modules
24
  from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
25
 
26
+ if 'ROCMExecutionProvider' in roop.globals.execution_providers:
27
+ del torch
28
+
29
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
30
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
31
 
 
37
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
38
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
39
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
40
+ program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
41
+ program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
42
+ program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
43
+ program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
44
+ program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
45
+ program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
46
+ program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
 
 
 
 
 
47
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
48
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
49
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
 
53
  roop.globals.source_path = args.source_path
54
  roop.globals.target_path = args.target_path
55
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
 
56
  roop.globals.frame_processors = args.frame_processor
57
+ roop.globals.headless = args.source_path or args.target_path or args.output_path
58
  roop.globals.keep_fps = args.keep_fps
59
+ roop.globals.keep_audio = args.keep_audio
60
  roop.globals.keep_frames = args.keep_frames
 
61
  roop.globals.many_faces = args.many_faces
62
+ roop.globals.video_encoder = args.video_encoder
63
+ roop.globals.video_quality = args.video_quality
 
 
 
 
 
64
  roop.globals.max_memory = args.max_memory
65
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
66
  roop.globals.execution_threads = args.execution_threads
 
75
  if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
76
 
77
 
78
+ def suggest_max_memory() -> int:
79
+ if platform.system().lower() == 'darwin':
80
+ return 4
81
+ return 16
82
+
83
+
84
  def suggest_execution_providers() -> List[str]:
85
  return encode_execution_providers(onnxruntime.get_available_providers())
86
 
87
 
88
  def suggest_execution_threads() -> int:
89
+ if 'DmlExecutionProvider' in roop.globals.execution_providers:
90
+ return 1
91
+ if 'ROCMExecutionProvider' in roop.globals.execution_providers:
92
+ return 1
93
+ return 8
94
 
95
 
96
  def limit_resources() -> None:
 
107
  memory = roop.globals.max_memory * 1024 ** 6
108
  if platform.system().lower() == 'windows':
109
  import ctypes
110
+ kernel32 = ctypes.windll.kernel32
111
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
112
  else:
113
  import resource
114
  resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
115
 
116
 
117
+ def release_resources() -> None:
118
+ if 'CUDAExecutionProvider' in roop.globals.execution_providers:
119
+ torch.cuda.empty_cache()
120
+
121
+
122
  def pre_check() -> bool:
123
  if sys.version_info < (3, 9):
124
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
 
144
  if predict_image(roop.globals.target_path):
145
  destroy()
146
  shutil.copy2(roop.globals.target_path, roop.globals.output_path)
 
147
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
148
  update_status('Progressing...', frame_processor.NAME)
149
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
150
  frame_processor.post_process()
151
+ release_resources()
152
  if is_image(roop.globals.target_path):
153
  update_status('Processing to image succeed!')
154
  else:
 
157
  # process image to videos
158
  if predict_video(roop.globals.target_path):
159
  destroy()
160
+ update_status('Creating temp resources...')
161
  create_temp(roop.globals.target_path)
162
+ update_status('Extracting frames...')
163
+ extract_frames(roop.globals.target_path)
 
 
 
 
 
 
 
164
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
165
+ for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
166
+ update_status('Progressing...', frame_processor.NAME)
167
+ frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
168
+ frame_processor.post_process()
169
+ release_resources()
170
+ # handles fps
 
 
 
171
  if roop.globals.keep_fps:
172
+ update_status('Detecting fps...')
173
  fps = detect_fps(roop.globals.target_path)
174
+ update_status(f'Creating video with {fps} fps...')
175
  create_video(roop.globals.target_path, fps)
176
  else:
177
+ update_status('Creating video with 30.0 fps...')
178
  create_video(roop.globals.target_path)
179
  # handle audio
180
+ if roop.globals.keep_audio:
 
 
 
181
  if roop.globals.keep_fps:
182
  update_status('Restoring audio...')
183
  else:
184
  update_status('Restoring audio might cause issues as fps are not kept...')
185
  restore_audio(roop.globals.target_path, roop.globals.output_path)
186
+ else:
187
+ move_temp(roop.globals.target_path, roop.globals.output_path)
188
+ # clean and validate
189
  clean_temp(roop.globals.target_path)
 
190
  if is_video(roop.globals.target_path):
191
  update_status('Processing to video succeed!')
192
  else:
 
196
  def destroy() -> None:
197
  if roop.globals.target_path:
198
  clean_temp(roop.globals.target_path)
199
+ quit()
200
 
201
 
202
  def run() -> None: