Arrcttacsrks commited on
Commit
a1bd7ff
·
verified ·
1 Parent(s): b18f1ec

Update roop/processors/frame/face_enhancer.py

Browse files
roop/processors/frame/face_enhancer.py CHANGED
@@ -2,6 +2,9 @@ from typing import Any, List, Callable
2
  import cv2
3
  import threading
4
  import gfpgan
 
 
 
5
 
6
  import roop.globals
7
  import roop.processors.frame.core
@@ -9,7 +12,6 @@ from roop.core import update_status
9
  from roop.face_analyser import get_one_face
10
  from roop.typing import Frame, Face
11
  from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
12
- import torch
13
 
14
  FACE_ENHANCER = None
15
  THREAD_SEMAPHORE = threading.Semaphore()
@@ -17,27 +19,27 @@ THREAD_LOCK = threading.Lock()
17
  NAME = 'ROOP.FACE-ENHANCER'
18
  frame_name = 'face_enhancer'
19
 
 
 
 
20
  if torch.cuda.is_available():
21
- device='cuda'
22
  else:
23
- device='cpu'
24
-
25
 
26
  def get_face_enhancer() -> Any:
27
  global FACE_ENHANCER
28
 
29
  with THREAD_LOCK:
30
  if FACE_ENHANCER is None:
31
- model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
32
- # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
33
- FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1,device=device) # type: ignore[attr-defined]
34
  return FACE_ENHANCER
35
 
36
 
37
  def pre_check() -> bool:
38
- download_directory_path = resolve_relative_path('../models')
39
- # conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])
40
- conditional_download(download_directory_path, ['https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'])
41
  return True
42
 
43
 
 
2
  import cv2
3
  import threading
4
  import gfpgan
5
+ import os
6
+ from huggingface_hub import hf_hub_download # Thêm import này
7
+ import torch
8
 
9
  import roop.globals
10
  import roop.processors.frame.core
 
12
  from roop.face_analyser import get_one_face
13
  from roop.typing import Frame, Face
14
  from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
 
15
 
16
  FACE_ENHANCER = None
17
  THREAD_SEMAPHORE = threading.Semaphore()
 
19
  NAME = 'ROOP.FACE-ENHANCER'
20
  frame_name = 'face_enhancer'
21
 
22
+ # Lấy token từ biến môi trường
23
+ token = os.getenv('HF_TOKEN')
24
+
25
  if torch.cuda.is_available():
26
+ device = 'cuda'
27
  else:
28
+ device = 'cpu'
 
29
 
30
  def get_face_enhancer() -> Any:
31
  global FACE_ENHANCER
32
 
33
  with THREAD_LOCK:
34
  if FACE_ENHANCER is None:
35
+ # Tải model từ Hugging Face Hub
36
+ model_path = hf_hub_download(repo_id="Arrcttacsrks/TencentARC_GFPGAN", filename="GFPGANv1.4.pth", token=token)
37
+ FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=device) # type: ignore[attr-defined]
38
  return FACE_ENHANCER
39
 
40
 
41
  def pre_check() -> bool:
42
+ # Không cần điều kiện download nữa vì đã tải mô hình trực tiếp từ Hugging Face Hub
 
 
43
  return True
44
 
45