chychiu commited on
Commit
b1be50b
·
1 Parent(s): e73e119
Files changed (1) hide show
  1. script.py +48 -37
script.py CHANGED
@@ -121,7 +121,7 @@ class ImageDataset(Dataset):
121
  image = cv2.imread(image_path)
122
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
123
 
124
- return self.transform(image=image)['image']
125
 
126
 
127
  class EmbeddingMetadataDataset(Dataset):
@@ -152,16 +152,20 @@ class EmbeddingMetadataDataset(Dataset):
152
 
153
  def generate_embeddings(metadata_file_path, root_dir):
154
 
 
 
155
  metadata_df = pd.read_csv(metadata_file_path)
156
 
157
  test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
158
 
159
- loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
160
 
161
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
162
  model = timm.create_model(
163
- "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
164
  )
 
 
165
 
166
  model = model.to(device)
167
  model.eval()
@@ -255,9 +259,7 @@ class FungiMEEModel(nn.Module):
255
  geo_emb = self.geo_embedding.forward(metadata["geo"].to(self.device))
256
  substr_emb = self.substr_embedding.forward(metadata["substr"].to(self.device))
257
 
258
- full_emb = torch.stack(
259
- (img_emb, date_emb, geo_emb, substr_emb), dim=1
260
- )
261
 
262
  cls_emb = self.encoder.forward(full_emb)[:, 0, :].squeeze(1)
263
 
@@ -305,37 +307,46 @@ class FungiEnsembleModel(nn.Module):
305
 
306
  return logits.argmax(1).tolist()
307
 
 
308
  def make_submission(metadata_df):
309
 
310
  OUTPUT_CSV_PATH = "./submission.csv"
311
  BASE_CKPT_PATH = "./checkpoints"
312
 
313
- ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
314
- # model_names = os.listdir(BASE_CKPT_PATH)
315
-
316
- # models = []
317
-
318
- # for model_path in model_names:
319
- # print("loading ", model_path)
320
- # ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
321
-
322
- # ckpt = torch.load(ckpt_path)
323
- # model = FungiMEEModel()
324
- # model.load_state_dict(
325
- # {w: ckpt["model." + w] for w in model.state_dict().keys()}
326
- # )
327
- # model.eval()
328
- # model.cuda()
 
 
 
 
 
 
 
 
329
 
330
- # models.append(model)
331
 
332
- # fungi_model = FungiEnsembleModel(models)
333
 
334
- fungi_model = FungiMEEModel()
335
- ckpt = torch.load(ckpt_path)
336
- fungi_model.load_state_dict(
337
- {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
338
- )
339
 
340
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
341
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
@@ -369,18 +380,18 @@ if __name__ == "__main__":
369
  MODEL_PATH = "metaformer-s-224.pth"
370
  MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
371
 
372
- # # # # Real submission
373
- # import zipfile
374
 
375
- # with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
376
- # zip_ref.extractall("/tmp/data")
377
 
378
- # metadata_file_path = "./_test_preprocessed.csv"
379
- # root_dir = "/tmp/data"
380
 
381
  # Test submission
382
- metadata_file_path = "../trial_submission.csv"
383
- root_dir = "../data/DF_FULL"
384
 
385
  ##############
386
 
 
121
  image = cv2.imread(image_path)
122
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
123
 
124
+ return self.transform(image=image)["image"]
125
 
126
 
127
  class EmbeddingMetadataDataset(Dataset):
 
152
 
153
  def generate_embeddings(metadata_file_path, root_dir):
154
 
155
+ DINOV2_CKPT = "./checkpoints/dinov2.bin"
156
+
157
  metadata_df = pd.read_csv(metadata_file_path)
158
 
159
  test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
160
 
161
+ loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=4)
162
 
163
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
164
  model = timm.create_model(
165
+ "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
166
  )
167
+ weights = torch.load(DINOV2_CKPT)
168
+ model.load_state_dict(weights)
169
 
170
  model = model.to(device)
171
  model.eval()
 
259
  geo_emb = self.geo_embedding.forward(metadata["geo"].to(self.device))
260
  substr_emb = self.substr_embedding.forward(metadata["substr"].to(self.device))
261
 
262
+ full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)
 
 
263
 
264
  cls_emb = self.encoder.forward(full_emb)[:, 0, :].squeeze(1)
265
 
 
307
 
308
  return logits.argmax(1).tolist()
309
 
310
+
311
  def make_submission(metadata_df):
312
 
313
  OUTPUT_CSV_PATH = "./submission.csv"
314
  BASE_CKPT_PATH = "./checkpoints"
315
 
316
+ # ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
317
+ model_names = [
318
+ "dino_2_optuna_05242231.ckpt",
319
+ "dino_optuna_05241449.ckpt",
320
+ "dino_optuna_05241257.ckpt",
321
+ "dino_optuna_05241222.ckpt",
322
+ "dino_2_optuna_05242055.ckpt",
323
+ "dino_2_optuna_05242156.ckpt",
324
+ "dino_2_optuna_05242344.ckpt",
325
+ ]
326
+
327
+ models = []
328
+
329
+ for model_path in model_names:
330
+ print("loading ", model_path)
331
+ ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
332
+
333
+ ckpt = torch.load(ckpt_path)
334
+ model = FungiMEEModel()
335
+ model.load_state_dict(
336
+ {w: ckpt["model." + w] for w in model.state_dict().keys()}
337
+ )
338
+ model.eval()
339
+ model.cuda()
340
 
341
+ models.append(model)
342
 
343
+ fungi_model = FungiEnsembleModel(models)
344
 
345
+ # fungi_model = FungiMEEModel()
346
+ # ckpt = torch.load(ckpt_path)
347
+ # fungi_model.load_state_dict(
348
+ # {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
349
+ # )
350
 
351
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
352
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
 
380
  MODEL_PATH = "metaformer-s-224.pth"
381
  MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
382
 
383
+ # # # # # Real submission
384
+ import zipfile
385
 
386
+ with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
387
+ zip_ref.extractall("/tmp/data")
388
 
389
+ metadata_file_path = "./_test_preprocessed.csv"
390
+ root_dir = "/tmp/data"
391
 
392
  # Test submission
393
+ # metadata_file_path = "../trial_submission.csv"
394
+ # root_dir = "../data/DF_FULL"
395
 
396
  ##############
397