Spaces:
Sleeping
Sleeping
gpu
Browse files
app.py
CHANGED
@@ -35,22 +35,6 @@ def load_model(model_path, feature_dim):
|
|
35 |
print(f"{model_path} loaded successfully.")
|
36 |
return model
|
37 |
|
38 |
-
image_size = 112
|
39 |
-
batch_size = 32
|
40 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
41 |
-
|
42 |
-
models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
|
43 |
-
{"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
|
44 |
-
{"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
|
45 |
-
{"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
|
46 |
-
models = []
|
47 |
-
for model_info in models_info:
|
48 |
-
model_name = model_info["name"]
|
49 |
-
feature_dim = model_info["feature_dim"]
|
50 |
-
model_path = f"checkpoints/{model_name}"
|
51 |
-
model = load_model(model_path, feature_dim)
|
52 |
-
models.append(model)
|
53 |
-
|
54 |
def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
|
55 |
filenames = load_filenames(img_dir)
|
56 |
train_X = filenames[:1000]
|
@@ -84,6 +68,25 @@ def load_keypoints(img_dir="resources/trainB/", image_size=112, batch_size=32):
|
|
84 |
print("Keypoints loaded successfully.")
|
85 |
return test_imgs, points
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
# ヒートマップの生成関数
|
88 |
@spaces.GPU
|
89 |
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
@@ -131,8 +134,9 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
|
131 |
plt.close(fig)
|
132 |
return fig
|
133 |
|
|
|
134 |
def setup(model_info, input_image=None):
|
135 |
-
global model,
|
136 |
# str -> dictに変換
|
137 |
if type(model_info) == str:
|
138 |
model_info = eval(model_info)
|
@@ -140,9 +144,6 @@ def setup(model_info, input_image=None):
|
|
140 |
index = models_info.index(model_info)
|
141 |
model = models[index]
|
142 |
|
143 |
-
x = load_data()
|
144 |
-
test_imgs, points = load_keypoints()
|
145 |
-
|
146 |
feature_map, _ = model(test_imgs)
|
147 |
mean_vector_list = utils.get_mean_vector(feature_map, points)
|
148 |
|
|
|
35 |
print(f"{model_path} loaded successfully.")
|
36 |
return model
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
|
39 |
filenames = load_filenames(img_dir)
|
40 |
train_X = filenames[:1000]
|
|
|
68 |
print("Keypoints loaded successfully.")
|
69 |
return test_imgs, points
|
70 |
|
71 |
+
image_size = 112
|
72 |
+
batch_size = 32
|
73 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
74 |
+
|
75 |
+
models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
|
76 |
+
{"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
|
77 |
+
{"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
|
78 |
+
{"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
|
79 |
+
models = []
|
80 |
+
for model_info in models_info:
|
81 |
+
model_name = model_info["name"]
|
82 |
+
feature_dim = model_info["feature_dim"]
|
83 |
+
model_path = f"checkpoints/{model_name}"
|
84 |
+
model = load_model(model_path, feature_dim)
|
85 |
+
models.append(model)
|
86 |
+
|
87 |
+
x = load_data()
|
88 |
+
test_imgs, points = load_keypoints()
|
89 |
+
|
90 |
# ヒートマップの生成関数
|
91 |
@spaces.GPU
|
92 |
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
|
|
134 |
plt.close(fig)
|
135 |
return fig
|
136 |
|
137 |
+
@spaces.GPU
|
138 |
def setup(model_info, input_image=None):
|
139 |
+
global model, mean_vector_list
|
140 |
# str -> dictに変換
|
141 |
if type(model_info) == str:
|
142 |
model_info = eval(model_info)
|
|
|
144 |
index = models_info.index(model_info)
|
145 |
model = models[index]
|
146 |
|
|
|
|
|
|
|
147 |
feature_map, _ = model(test_imgs)
|
148 |
mean_vector_list = utils.get_mean_vector(feature_map, points)
|
149 |
|