Apex-X commited on
Commit
f65c624
1 Parent(s): eb31f77

Upload face_enhancer (2).py

Browse files
roop/processors/frame/face_enhancer (2).py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import threading
4
+ import gfpgan
5
+
6
+ import roop.globals
7
+ import roop.processors.frame.core
8
+ from roop.core import update_status
9
+ from roop.face_analyser import get_one_face
10
+ from roop.typing import Frame, Face
11
+ from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
12
+
13
+ FACE_ENHANCER = None
14
+ THREAD_SEMAPHORE = threading.Semaphore()
15
+ THREAD_LOCK = threading.Lock()
16
+ NAME = 'ROOP.FACE-ENHANCER'
17
+
18
+
19
+ def get_face_enhancer() -> Any:
20
+ global FACE_ENHANCER
21
+
22
+ with THREAD_LOCK:
23
+ if FACE_ENHANCER is None:
24
+ model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
25
+ # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
26
+ FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
27
+ return FACE_ENHANCER
28
+
29
+
30
+ def pre_check() -> bool:
31
+ download_directory_path = resolve_relative_path('../models')
32
+ conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])
33
+ return True
34
+
35
+
36
+ def pre_start() -> bool:
37
+ if not is_image(roop.globals.target_path) and not is_video(roop.globals.target_path):
38
+ update_status('Select an image or video for target path.', NAME)
39
+ return False
40
+ return True
41
+
42
+
43
+ def post_process() -> None:
44
+ global FACE_ENHANCER
45
+
46
+ FACE_ENHANCER = None
47
+
48
+
49
+ def enhance_face(temp_frame: Frame) -> Frame:
50
+ with THREAD_SEMAPHORE:
51
+ _, _, temp_frame = get_face_enhancer().enhance(
52
+ temp_frame,
53
+ paste_back=True
54
+ )
55
+ return temp_frame
56
+
57
+
58
+ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
59
+ target_face = get_one_face(temp_frame)
60
+ if target_face:
61
+ temp_frame = enhance_face(temp_frame)
62
+ return temp_frame
63
+
64
+
65
+ def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
66
+ for temp_frame_path in temp_frame_paths:
67
+ temp_frame = cv2.imread(temp_frame_path)
68
+ result = process_frame(None, temp_frame)
69
+ cv2.imwrite(temp_frame_path, result)
70
+ if update:
71
+ update()
72
+
73
+
74
+ def process_image(source_path: str, target_path: str, output_path: str) -> None:
75
+ target_frame = cv2.imread(target_path)
76
+ result = process_frame(None, target_frame)
77
+ cv2.imwrite(output_path, result)
78
+
79
+
80
+ def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
81
+ roop.processors.frame.core.process_video(None, temp_frame_paths, process_frames)