yeq6x commited on
Commit
ff3dcda
·
1 Parent(s): 7f1644b
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -19,6 +19,12 @@ image_size = 112
19
  batch_size = 32
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
 
 
 
 
 
 
22
  @spaces.GPU
23
  def load_model(model_path, feature_dim):
24
  model = AutoencoderModule(feature_dim=feature_dim)
@@ -141,21 +147,6 @@ def setup(model_info, input_image=None):
141
  fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
142
  return fig
143
 
144
-
145
- models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
146
- {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
147
- {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
148
- {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
149
- models = []
150
- for model_info in models_info:
151
- model_name = model_info["name"]
152
- feature_dim = model_info["feature_dim"]
153
- model_path = f"checkpoints/{model_name}"
154
- model = load_model(model_path, feature_dim)
155
- models.append(model)
156
-
157
- setup(models_info[0])
158
-
159
  with gr.Blocks() as demo:
160
  # title
161
  gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
@@ -204,7 +195,18 @@ with gr.Blocks() as demo:
204
  ],
205
  inputs=[input_image],
206
  )
207
-
 
 
 
 
 
 
 
 
 
 
 
208
  # JavaScriptコードをロード
209
  demo.launch()
210
 
 
19
  batch_size = 32
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
23
+ {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
24
+ {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
25
+ {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
26
+ models = []
27
+
28
  @spaces.GPU
29
  def load_model(model_path, feature_dim):
30
  model = AutoencoderModule(feature_dim=feature_dim)
 
147
  fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
148
  return fig
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  with gr.Blocks() as demo:
151
  # title
152
  gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
 
195
  ],
196
  inputs=[input_image],
197
  )
198
+
199
+
200
+ if __name__ == "__main__":
201
+ for model_info in models_info:
202
+ model_name = model_info["name"]
203
+ feature_dim = model_info["feature_dim"]
204
+ model_path = f"checkpoints/{model_name}"
205
+ model = load_model(model_path, feature_dim)
206
+ models.append(model)
207
+
208
+ setup(models_info[0])
209
+
210
  # JavaScriptコードをロード
211
  demo.launch()
212