Spaces:
Sleeping
Sleeping
gpu
Browse files
app.py
CHANGED
@@ -19,7 +19,7 @@ image_size = 112
|
|
19 |
batch_size = 32
|
20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
|
22 |
-
def load_model(model_path
|
23 |
model = AutoencoderModule(feature_dim=feature_dim)
|
24 |
state_dict = torch.load(model_path)
|
25 |
|
@@ -39,7 +39,7 @@ def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt
|
|
39 |
print("Model loaded successfully.")
|
40 |
return model
|
41 |
|
42 |
-
def load_data(
|
43 |
filenames = load_filenames(img_dir)
|
44 |
train_X = filenames[:1000]
|
45 |
|
@@ -55,7 +55,7 @@ def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=25
|
|
55 |
print("Data loaded successfully.")
|
56 |
return x
|
57 |
|
58 |
-
def load_keypoints(
|
59 |
filenames = load_filenames(img_dir)
|
60 |
train_X = filenames[:1000]
|
61 |
keypoints = dataset.load_keypoints('resources/DataList.json')
|
@@ -119,31 +119,39 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
|
119 |
plt.close(fig)
|
120 |
return fig
|
121 |
|
122 |
-
def setup(
|
123 |
global model, x, test_imgs, points, mean_vector_list
|
124 |
# str -> dictに変換
|
125 |
-
if type(
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
x = load_data(
|
132 |
-
test_imgs, points = load_keypoints(
|
|
|
133 |
feature_map, _ = model(test_imgs)
|
134 |
mean_vector_list = utils.get_mean_vector(feature_map, points)
|
135 |
|
136 |
if input_image is not None:
|
137 |
fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
|
138 |
return fig
|
139 |
-
|
140 |
|
141 |
-
|
|
|
142 |
{"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
|
143 |
{"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
|
144 |
{"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
with gr.Blocks() as demo:
|
149 |
# title
|
|
|
19 |
batch_size = 32
|
20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
|
22 |
+
def load_model(model_path, feature_dim):
|
23 |
model = AutoencoderModule(feature_dim=feature_dim)
|
24 |
state_dict = torch.load(model_path)
|
25 |
|
|
|
39 |
print("Model loaded successfully.")
|
40 |
return model
|
41 |
|
42 |
+
def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
|
43 |
filenames = load_filenames(img_dir)
|
44 |
train_X = filenames[:1000]
|
45 |
|
|
|
55 |
print("Data loaded successfully.")
|
56 |
return x
|
57 |
|
58 |
+
def load_keypoints(img_dir="resources/trainB/", image_size=112, batch_size=32):
|
59 |
filenames = load_filenames(img_dir)
|
60 |
train_X = filenames[:1000]
|
61 |
keypoints = dataset.load_keypoints('resources/DataList.json')
|
|
|
119 |
plt.close(fig)
|
120 |
return fig
|
121 |
|
122 |
+
def setup(model_info, input_image=None):
|
123 |
global model, x, test_imgs, points, mean_vector_list
|
124 |
# str -> dictに変換
|
125 |
+
if type(model_info) == str:
|
126 |
+
model_info = eval(model_info)
|
127 |
+
|
128 |
+
index = models_info.index(model_info)
|
129 |
+
model = models[index]
|
130 |
+
|
131 |
+
x = load_data()
|
132 |
+
test_imgs, points = load_keypoints()
|
133 |
+
|
134 |
feature_map, _ = model(test_imgs)
|
135 |
mean_vector_list = utils.get_mean_vector(feature_map, points)
|
136 |
|
137 |
if input_image is not None:
|
138 |
fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
|
139 |
return fig
|
|
|
140 |
|
141 |
+
|
142 |
+
models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
|
143 |
{"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
|
144 |
{"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
|
145 |
{"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
|
146 |
+
models = []
|
147 |
+
for model_info in models_info:
|
148 |
+
model_name = model_info["name"]
|
149 |
+
feature_dim = model_info["feature_dim"]
|
150 |
+
model_path = f"checkpoints/{model_name}"
|
151 |
+
model = load_model(model_path, feature_dim)
|
152 |
+
models.append(model_name)
|
153 |
+
|
154 |
+
setup(models_info[0])
|
155 |
|
156 |
with gr.Blocks() as demo:
|
157 |
# title
|