yeq6x commited on
Commit
26ba1d3
·
1 Parent(s): 50f5011
Files changed (1) hide show
  1. app.py +24 -16
app.py CHANGED
@@ -19,7 +19,7 @@ image_size = 112
19
  batch_size = 32
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
23
  model = AutoencoderModule(feature_dim=feature_dim)
24
  state_dict = torch.load(model_path)
25
 
@@ -39,7 +39,7 @@ def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt
39
  print("Model loaded successfully.")
40
  return model
41
 
42
- def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
43
  filenames = load_filenames(img_dir)
44
  train_X = filenames[:1000]
45
 
@@ -55,7 +55,7 @@ def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=25
55
  print("Data loaded successfully.")
56
  return x
57
 
58
- def load_keypoints(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
59
  filenames = load_filenames(img_dir)
60
  train_X = filenames[:1000]
61
  keypoints = dataset.load_keypoints('resources/DataList.json')
@@ -119,31 +119,39 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
119
  plt.close(fig)
120
  return fig
121
 
122
- def setup(model_dict, input_image=None):
123
  global model, x, test_imgs, points, mean_vector_list
124
  # str -> dictに変換
125
- if type(model_dict) == str:
126
- model_dict = eval(model_dict)
127
- model_name = model_dict["name"]
128
- feature_dim = model_dict["feature_dim"]
129
- model_path = f"checkpoints/{model_name}"
130
- model = load_model(model_path, feature_dim)
131
- x = load_data(device)
132
- test_imgs, points = load_keypoints(device)
 
133
  feature_map, _ = model(test_imgs)
134
  mean_vector_list = utils.get_mean_vector(feature_map, points)
135
 
136
  if input_image is not None:
137
  fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
138
  return fig
139
-
140
 
141
- models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
 
142
  {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
143
  {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
144
  {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
145
-
146
- setup(models[0])
 
 
 
 
 
 
 
147
 
148
  with gr.Blocks() as demo:
149
  # title
 
19
  batch_size = 32
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ def load_model(model_path, feature_dim):
23
  model = AutoencoderModule(feature_dim=feature_dim)
24
  state_dict = torch.load(model_path)
25
 
 
39
  print("Model loaded successfully.")
40
  return model
41
 
42
+ def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
43
  filenames = load_filenames(img_dir)
44
  train_X = filenames[:1000]
45
 
 
55
  print("Data loaded successfully.")
56
  return x
57
 
58
+ def load_keypoints(img_dir="resources/trainB/", image_size=112, batch_size=32):
59
  filenames = load_filenames(img_dir)
60
  train_X = filenames[:1000]
61
  keypoints = dataset.load_keypoints('resources/DataList.json')
 
119
  plt.close(fig)
120
  return fig
121
 
122
+ def setup(model_info, input_image=None):
123
  global model, x, test_imgs, points, mean_vector_list
124
  # str -> dictに変換
125
+ if type(model_info) == str:
126
+ model_info = eval(model_info)
127
+
128
+ index = models_info.index(model_info)
129
+ model = models[index]
130
+
131
+ x = load_data()
132
+ test_imgs, points = load_keypoints()
133
+
134
  feature_map, _ = model(test_imgs)
135
  mean_vector_list = utils.get_mean_vector(feature_map, points)
136
 
137
  if input_image is not None:
138
  fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
139
  return fig
 
140
 
141
+
142
+ models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
143
  {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
144
  {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
145
  {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
146
+ models = []
147
+ for model_info in models_info:
148
+ model_name = model_info["name"]
149
+ feature_dim = model_info["feature_dim"]
150
+ model_path = f"checkpoints/{model_name}"
151
+ model = load_model(model_path, feature_dim)
152
+ models.append(model_name)
153
+
154
+ setup(models_info[0])
155
 
156
  with gr.Blocks() as demo:
157
  # title