Apex-X commited on
Commit
eb31f77
1 Parent(s): 5b3381b

face_enhancer.py

Browse files
roop_processors_frame_face_enhancer.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import threading
4
+ import gfpgan
5
+
6
+ import roop.globals
7
+ import roop.processors.frame.core
8
+ from roop.core import update_status
9
+ from roop.face_analyser import get_one_face
10
+ from roop.typing import Frame, Face
11
+ from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
12
+ import torch
13
+
14
+ FACE_ENHANCER = None
15
+ THREAD_SEMAPHORE = threading.Semaphore()
16
+ THREAD_LOCK = threading.Lock()
17
+ NAME = 'ROOP.FACE-ENHANCER'
18
+ frame_name = 'face_enhancer'
19
+
20
+ if torch.cuda.is_available():
21
+ device='cuda'
22
+ else:
23
+ device='cpu'
24
+
25
+
26
+ def get_face_enhancer() -> Any:
27
+ global FACE_ENHANCER
28
+
29
+ with THREAD_LOCK:
30
+ if FACE_ENHANCER is None:
31
+ model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
32
+ # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
33
+ FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1,device=device) # type: ignore[attr-defined]
34
+ return FACE_ENHANCER
35
+
36
+
37
+ def pre_check() -> bool:
38
+ download_directory_path = resolve_relative_path('../models')
39
+ # conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])
40
+ conditional_download(download_directory_path, ['https://huggingface.co/Apex-X/gfpgan.pth/resolve/main/GFPGANv1.4.pth'])
41
+ return True
42
+
43
+
44
+ def pre_start() -> bool:
45
+ if not is_image(roop.globals.target_path) and not is_video(roop.globals.target_path):
46
+ update_status('Select an image or video for target path.', NAME)
47
+ return False
48
+ return True
49
+
50
+
51
+ def post_process() -> None:
52
+ global FACE_ENHANCER
53
+
54
+ FACE_ENHANCER = None
55
+
56
+
57
+ def enhance_face(temp_frame: Frame) -> Frame:
58
+ with THREAD_SEMAPHORE:
59
+ _, _, temp_frame = get_face_enhancer().enhance(
60
+ temp_frame,
61
+ paste_back=True
62
+ )
63
+ return temp_frame
64
+
65
+
66
+ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
67
+ target_face = get_one_face(temp_frame)
68
+ if target_face:
69
+ temp_frame = enhance_face(temp_frame)
70
+ return temp_frame
71
+
72
+
73
+ def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
74
+ for temp_frame_path in temp_frame_paths:
75
+ temp_frame = cv2.imread(temp_frame_path)
76
+ result = process_frame(None, temp_frame)
77
+ cv2.imwrite(temp_frame_path, result)
78
+ if update:
79
+ update()
80
+
81
+
82
+ def process_image(source_path: str, target_path: str, output_path: str) -> None:
83
+ target_frame = cv2.imread(target_path)
84
+ result = process_frame(None, target_frame)
85
+ cv2.imwrite(output_path, result)
86
+
87
+
88
+ def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
89
+ roop.processors.frame.core.process_video(None, temp_frame_paths, process_frames)