Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,19 +4,20 @@ import torchvision.transforms as transforms
|
|
4 |
from pathlib import Path
|
5 |
import os
|
6 |
|
7 |
-
#
|
8 |
-
|
|
|
9 |
TARGET_EPOCH = 47
|
10 |
|
11 |
-
# Hàm tải mô hình
|
12 |
def load_model(model_name, size):
|
13 |
-
model_directory = Path(
|
14 |
if not model_directory.exists():
|
15 |
-
raise FileNotFoundError(f"Không tìm thấy thư mục
|
16 |
-
|
17 |
-
# Tải mô hình
|
18 |
-
|
19 |
-
|
20 |
return generator, mapping_network
|
21 |
|
22 |
# Hàm sinh ảnh
|
@@ -38,10 +39,10 @@ def generate_image(model_name, size, alpha=0.5):
|
|
38 |
|
39 |
# Lấy danh sách mô hình và kích thước
|
40 |
def get_model_names():
|
41 |
-
return [folder.name for folder in Path(
|
42 |
|
43 |
def get_sizes(model_name):
|
44 |
-
sizes = [folder.name for folder in (Path(
|
45 |
return sizes
|
46 |
|
47 |
# Tạo giao diện Gradio
|
|
|
4 |
from pathlib import Path
|
5 |
import os
|
6 |
|
7 |
+
# Lấy Hugging Face token từ biến môi trường
|
8 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # Token từ biến môi trường
|
9 |
+
MODEL_REPO = "Arrcttacsrks/FaceStyleGan"
|
10 |
TARGET_EPOCH = 47
|
11 |
|
12 |
+
# Hàm tải mô hình từ thư mục local
|
13 |
def load_model(model_name, size):
|
14 |
+
model_directory = Path(f"./models/{model_name}/{size}")
|
15 |
if not model_directory.exists():
|
16 |
+
raise FileNotFoundError(f"Không tìm thấy thư mục mô hình: {model_directory}")
|
17 |
+
|
18 |
+
# Tải trọng số mô hình
|
19 |
+
mapping_network = torch.load(model_directory / f"mapping_network_{TARGET_EPOCH}.pth", map_location="cpu")
|
20 |
+
generator = torch.load(model_directory / f"generator_{TARGET_EPOCH}.pth", map_location="cpu")
|
21 |
return generator, mapping_network
|
22 |
|
23 |
# Hàm sinh ảnh
|
|
|
39 |
|
40 |
# Lấy danh sách mô hình và kích thước
|
41 |
def get_model_names():
|
42 |
+
return [folder.name for folder in Path("./models").glob("*") if folder.is_dir()]
|
43 |
|
44 |
def get_sizes(model_name):
|
45 |
+
sizes = [folder.name for folder in (Path("./models") / model_name).iterdir() if folder.is_dir()]
|
46 |
return sizes
|
47 |
|
48 |
# Tạo giao diện Gradio
|