yeq6x commited on
Commit
fddf24a
·
1 Parent(s): 5886555
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -83,12 +83,12 @@ for model_info in models_info:
83
  model_name = model_info["name"]
84
  feature_dim = model_info["feature_dim"]
85
  model_path = f"checkpoints/{model_name}"
86
- model = load_model(model_path, feature_dim)
87
- models.append(model)
88
 
89
  x = load_data()
90
  test_imgs, points = load_keypoints()
91
- model = None
 
92
 
93
  # ヒートマップの生成関数
94
  @spaces.GPU
@@ -141,15 +141,14 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
141
 
142
  @spaces.GPU
143
  def setup(model_info, input_image=None):
144
- global model, mean_vector_list
145
  # str -> dictに変換
146
  if type(model_info) == str:
147
  model_info = eval(model_info)
148
-
149
- index = models_info.index(model_info)
150
- model = models[index]
151
 
152
- feature_map, _ = model(test_imgs)
 
 
153
  mean_vector_list = utils.get_mean_vector(feature_map, points)
154
 
155
  if input_image is not None:
 
83
  model_name = model_info["name"]
84
  feature_dim = model_info["feature_dim"]
85
  model_path = f"checkpoints/{model_name}"
86
+ models.append(load_model(model_path, feature_dim))
 
87
 
88
  x = load_data()
89
  test_imgs, points = load_keypoints()
90
+ mean_vector_list = []
91
+ model_index = 0
92
 
93
  # ヒートマップの生成関数
94
  @spaces.GPU
 
141
 
142
  @spaces.GPU
143
  def setup(model_info, input_image=None):
144
+ global model_index, mean_vector_list
145
  # str -> dictに変換
146
  if type(model_info) == str:
147
  model_info = eval(model_info)
 
 
 
148
 
149
+ model_index = models_info.index(model_info)
150
+
151
+ feature_map, _ = models[model_index](test_imgs)
152
  mean_vector_list = utils.get_mean_vector(feature_map, points)
153
 
154
  if input_image is not None: