chychiu commited on
Commit
be31ee2
·
1 Parent(s): 96597c6

single batch, single model

Browse files
Files changed (1) hide show
  1. script.py +33 -35
script.py CHANGED
@@ -158,7 +158,7 @@ def generate_embeddings(metadata_file_path, root_dir):
158
 
159
  test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
160
 
161
- loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=4)
162
 
163
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
164
  model = timm.create_model(
@@ -313,40 +313,41 @@ def make_submission(metadata_df):
313
  OUTPUT_CSV_PATH = "./submission.csv"
314
  BASE_CKPT_PATH = "./checkpoints"
315
 
316
- # ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
317
- model_names = [
318
- "dino_2_optuna_05242231.ckpt",
319
- "dino_optuna_05241449.ckpt",
320
- "dino_optuna_05241257.ckpt",
321
- "dino_optuna_05241222.ckpt",
322
- "dino_2_optuna_05242055.ckpt",
323
- "dino_2_optuna_05242156.ckpt",
324
- "dino_2_optuna_05242344.ckpt",
325
- ]
326
-
327
- models = []
328
-
329
- for model_path in model_names:
330
- print("loading ", model_path)
331
- ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
332
-
333
- ckpt = torch.load(ckpt_path)
334
- model = FungiMEEModel()
335
- model.load_state_dict(
336
- {w: ckpt["model." + w] for w in model.state_dict().keys()}
337
- )
338
- model.eval()
339
- model.cuda()
340
 
341
- models.append(model)
342
 
343
- fungi_model = FungiEnsembleModel(models)
344
 
345
- # fungi_model = FungiMEEModel()
346
- # ckpt = torch.load(ckpt_path)
347
- # fungi_model.load_state_dict(
348
- # {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
349
- # )
 
 
350
 
351
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
352
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
@@ -377,9 +378,6 @@ def make_submission(metadata_df):
377
 
378
  if __name__ == "__main__":
379
 
380
- MODEL_PATH = "metaformer-s-224.pth"
381
- MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
382
-
383
  # # # # # # Real submission
384
  import zipfile
385
 
 
158
 
159
  test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
160
 
161
+ loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
162
 
163
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
164
  model = timm.create_model(
 
313
  OUTPUT_CSV_PATH = "./submission.csv"
314
  BASE_CKPT_PATH = "./checkpoints"
315
 
316
+ # model_names = [
317
+ # "dino_2_optuna_05242231.ckpt",
318
+ # "dino_optuna_05241449.ckpt",
319
+ # "dino_optuna_05241257.ckpt",
320
+ # "dino_optuna_05241222.ckpt",
321
+ # "dino_2_optuna_05242055.ckpt",
322
+ # "dino_2_optuna_05242156.ckpt",
323
+ # "dino_2_optuna_05242344.ckpt",
324
+ # ]
325
+
326
+ # models = []
327
+
328
+ # for model_path in model_names:
329
+ # print("loading ", model_path)
330
+ # ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
331
+
332
+ # ckpt = torch.load(ckpt_path)
333
+ # model = FungiMEEModel()
334
+ # model.load_state_dict(
335
+ # {w: ckpt["model." + w] for w in model.state_dict().keys()}
336
+ # )
337
+ # model.eval()
338
+ # model.cuda()
 
339
 
340
+ # models.append(model)
341
 
342
+ # fungi_model = FungiEnsembleModel(models)
343
 
344
+ ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
345
+
346
+ fungi_model = FungiMEEModel()
347
+ ckpt = torch.load(ckpt_path)
348
+ fungi_model.load_state_dict(
349
+ {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
350
+ )
351
 
352
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
353
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
 
378
 
379
  if __name__ == "__main__":
380
 
 
 
 
381
  # # # # # # Real submission
382
  import zipfile
383