Arrcttacsrks commited on
Commit
417d5cd
·
verified ·
1 Parent(s): 82091f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -4,19 +4,20 @@ import torchvision.transforms as transforms
4
  from pathlib import Path
5
  import os
6
 
7
- # Cấu hình
8
- MODEL_PATH = "./models"
 
9
  TARGET_EPOCH = 47
10
 
11
- # Hàm tải mô hình
12
  def load_model(model_name, size):
13
- model_directory = Path(MODEL_PATH) / model_name / str(size)
14
  if not model_directory.exists():
15
- raise FileNotFoundError(f"Không tìm thấy thư mục model: {model_directory}")
16
-
17
- # Tải mô hình với trọng số
18
- generator = torch.load(model_directory / f"generator_{TARGET_EPOCH}.pth")
19
- mapping_network = torch.load(model_directory / f"mapping_network_{TARGET_EPOCH}.pth")
20
  return generator, mapping_network
21
 
22
  # Hàm sinh ảnh
@@ -38,10 +39,10 @@ def generate_image(model_name, size, alpha=0.5):
38
 
39
  # Lấy danh sách mô hình và kích thước
40
  def get_model_names():
41
- return [folder.name for folder in Path(MODEL_PATH).glob("*") if folder.is_dir()]
42
 
43
  def get_sizes(model_name):
44
- sizes = [folder.name for folder in (Path(MODEL_PATH) / model_name).iterdir() if folder.is_dir()]
45
  return sizes
46
 
47
  # Tạo giao diện Gradio
 
4
  from pathlib import Path
5
  import os
6
 
7
+ # Lấy Hugging Face token từ biến môi trường
8
+ HF_TOKEN = os.getenv("HF_TOKEN") # Token từ biến môi trường
9
+ MODEL_REPO = "Arrcttacsrks/FaceStyleGan"
10
  TARGET_EPOCH = 47
11
 
12
+ # Hàm tải mô hình từ thư mục local
13
  def load_model(model_name, size):
14
+ model_directory = Path(f"./models/{model_name}/{size}")
15
  if not model_directory.exists():
16
+ raise FileNotFoundError(f"Không tìm thấy thư mục mô hình: {model_directory}")
17
+
18
+ # Tải trọng số mô hình
19
+ mapping_network = torch.load(model_directory / f"mapping_network_{TARGET_EPOCH}.pth", map_location="cpu")
20
+ generator = torch.load(model_directory / f"generator_{TARGET_EPOCH}.pth", map_location="cpu")
21
  return generator, mapping_network
22
 
23
  # Hàm sinh ảnh
 
39
 
40
  # Lấy danh sách mô hình và kích thước
41
  def get_model_names():
42
+ return [folder.name for folder in Path("./models").glob("*") if folder.is_dir()]
43
 
44
  def get_sizes(model_name):
45
+ sizes = [folder.name for folder in (Path("./models") / model_name).iterdir() if folder.is_dir()]
46
  return sizes
47
 
48
  # Tạo giao diện Gradio