yeq6x commited on
Commit
0c6ebb1
·
1 Parent(s): 00e1057
Files changed (1) hide show
  1. app.py +21 -20
app.py CHANGED
@@ -35,22 +35,6 @@ def load_model(model_path, feature_dim):
35
  print(f"{model_path} loaded successfully.")
36
  return model
37
 
38
- image_size = 112
39
- batch_size = 32
40
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
-
42
- models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
43
- {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
44
- {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
45
- {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
46
- models = []
47
- for model_info in models_info:
48
- model_name = model_info["name"]
49
- feature_dim = model_info["feature_dim"]
50
- model_path = f"checkpoints/{model_name}"
51
- model = load_model(model_path, feature_dim)
52
- models.append(model)
53
-
54
  def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
55
  filenames = load_filenames(img_dir)
56
  train_X = filenames[:1000]
@@ -84,6 +68,25 @@ def load_keypoints(img_dir="resources/trainB/", image_size=112, batch_size=32):
84
  print("Keypoints loaded successfully.")
85
  return test_imgs, points
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # ヒートマップの生成関数
88
  @spaces.GPU
89
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
@@ -131,8 +134,9 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
131
  plt.close(fig)
132
  return fig
133
 
 
134
  def setup(model_info, input_image=None):
135
- global model, x, test_imgs, points, mean_vector_list
136
  # str -> dictに変換
137
  if type(model_info) == str:
138
  model_info = eval(model_info)
@@ -140,9 +144,6 @@ def setup(model_info, input_image=None):
140
  index = models_info.index(model_info)
141
  model = models[index]
142
 
143
- x = load_data()
144
- test_imgs, points = load_keypoints()
145
-
146
  feature_map, _ = model(test_imgs)
147
  mean_vector_list = utils.get_mean_vector(feature_map, points)
148
 
 
35
  print(f"{model_path} loaded successfully.")
36
  return model
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
39
  filenames = load_filenames(img_dir)
40
  train_X = filenames[:1000]
 
68
  print("Keypoints loaded successfully.")
69
  return test_imgs, points
70
 
71
+ image_size = 112
72
+ batch_size = 32
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+
75
+ models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
76
+ {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
77
+ {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
78
+ {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
79
+ models = []
80
+ for model_info in models_info:
81
+ model_name = model_info["name"]
82
+ feature_dim = model_info["feature_dim"]
83
+ model_path = f"checkpoints/{model_name}"
84
+ model = load_model(model_path, feature_dim)
85
+ models.append(model)
86
+
87
+ x = load_data()
88
+ test_imgs, points = load_keypoints()
89
+
90
  # ヒートマップの生成関数
91
  @spaces.GPU
92
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
 
134
  plt.close(fig)
135
  return fig
136
 
137
+ @spaces.GPU
138
  def setup(model_info, input_image=None):
139
+ global model, mean_vector_list
140
  # str -> dictに変換
141
  if type(model_info) == str:
142
  model_info = eval(model_info)
 
144
  index = models_info.index(model_info)
145
  model = models[index]
146
 
 
 
 
147
  feature_map, _ = model(test_imgs)
148
  mean_vector_list = utils.get_mean_vector(feature_map, points)
149