yeq6x committed
Commit 156c303 · 1 Parent(s): 23281d5
Files changed (1)
1. app.py +22 -48
app.py CHANGED
@@ -125,31 +125,6 @@ try:
        plt.close(fig)
        return fig

-    def setup(model_dict, input_image=None):
-        global model, device, x, test_imgs, points, mean_vector_list
-        # convert str -> dict
-        if type(model_dict) == str:
-            model_dict = eval(model_dict)
-        model_name = model_dict["name"]
-        feature_dim = model_dict["feature_dim"]
-        model_path = f"checkpoints/{model_name}"
-        model, device = load_model(model_path, feature_dim)
-        x = load_data(device)
-        test_imgs, points = load_keypoints(device)
-        feature_map, _ = model(test_imgs)
-        mean_vector_list = utils.get_mean_vector(feature_map, points)
-
-        if input_image is not None:
-            fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
-            return fig
-
-
-    models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
-              {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
-              {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
-              {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
-
-    setup(models[0])
except:
    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
        if type(uploaded_image) == str:
@@ -196,32 +171,31 @@ except:
        plt.close(fig)
        return fig

-    def setup(model_dict, input_image=None):
-        global model, device, x, test_imgs, points, mean_vector_list
-        # convert str -> dict
-        if type(model_dict) == str:
-            model_dict = eval(model_dict)
-        model_name = model_dict["name"]
-        feature_dim = model_dict["feature_dim"]
-        model_path = f"checkpoints/{model_name}"
-        model, device = load_model(model_path, feature_dim)
-        x = load_data(device)
-        test_imgs, points = load_keypoints(device)
-        feature_map, _ = model(test_imgs)
-        mean_vector_list = utils.get_mean_vector(feature_map, points)
+def setup(model_dict, input_image=None):
+    global model, device, x, test_imgs, points, mean_vector_list
+    # convert str -> dict
+    if type(model_dict) == str:
+        model_dict = eval(model_dict)
+    model_name = model_dict["name"]
+    feature_dim = model_dict["feature_dim"]
+    model_path = f"checkpoints/{model_name}"
+    model, device = load_model(model_path, feature_dim)
+    x = load_data(device)
+    test_imgs, points = load_keypoints(device)
+    feature_map, _ = model(test_imgs)
+    mean_vector_list = utils.get_mean_vector(feature_map, points)
+
+    if input_image is not None:
+        fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
+        return fig

-        if input_image is not None:
-            fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
-            return fig
-
-
-    models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
-              {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
-              {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
-              {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]

-    setup(models[0])
+models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
+          {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
+          {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
+          {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]

+setup(models[0])

with gr.Blocks() as demo:
    # title
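
The "convert str -> dict" step in setup rebuilds a model entry from its string form with eval, which will execute any expression it is handed. A minimal alternative sketch, assuming the string is always the literal repr of one of the dicts in models; the helper name parse_model_dict is illustrative and not part of app.py:

import ast

def parse_model_dict(model_dict):
    # Hypothetical helper: accept either a dict or its string repr.
    # ast.literal_eval only evaluates Python literals, unlike eval.
    if isinstance(model_dict, str):
        model_dict = ast.literal_eval(model_dict)
    return model_dict

entry = {"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32}
assert parse_model_dict(str(entry)) == entry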
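
The str-to-dict conversion suggests setup is also used as an event handler whose model argument arrives as a string, for example from a dropdown whose choices are the stringified entries of models. A hedged sketch of such wiring, assuming component names (model_dropdown, input_image, heatmap_plot) that do not appear in this hunk:

import gradio as gr

with gr.Blocks() as demo:
    # Hypothetical UI wiring; only gr.Blocks itself appears in this diff.
    model_dropdown = gr.Dropdown(choices=[str(m) for m in models],
                                 value=str(models[0]), label="checkpoint")
    input_image = gr.Image(type="filepath", label="input image")
    heatmap_plot = gr.Plot(label="heatmaps")
    # setup(model_dict, input_image) reloads the chosen model and, when an
    # image is given, returns a matplotlib figure for the Plot output.
    model_dropdown.change(fn=setup, inputs=[model_dropdown, input_image],
                          outputs=heatmap_plot)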