chychiu committed on
Commit
e73e119
·
1 Parent(s): 355e661

added dinov2 weights

Browse files
Files changed (2) hide show
  1. checkpoints/dinov2.bin +3 -0
  2. script.py +45 -39
checkpoints/dinov2.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78971dc00a0c488f2b2dff17d6dcb7ebe787af70a703d8212b38fc6a33dbcdd4
3
+ size 1217608166
script.py CHANGED
@@ -4,12 +4,13 @@ import pandas as pd
4
  import timm
5
  import torch
6
  import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torchvision.transforms as T
9
  from PIL import Image
10
  from timm.models.metaformer import MlpHead
11
  from torch.utils.data import DataLoader, Dataset
12
  from tqdm import tqdm
 
 
 
13
 
14
  DIM = 518
15
  DATE_SIZE = 4
@@ -99,11 +100,11 @@ SUBSTRATE = [
99
  class ImageDataset(Dataset):
100
  def __init__(self, df, local_filepath):
101
  self.df = df
102
- self.transform = T.Compose(
103
  [
104
- T.Resize((DIM, DIM)),
105
- T.ToTensor(),
106
- T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
107
  ]
108
  )
109
 
@@ -117,9 +118,10 @@ class ImageDataset(Dataset):
117
  def __getitem__(self, idx):
118
  image_path = os.path.join(self.local_filepath, self.filepaths[idx])
119
 
120
- image = Image.open(image_path).convert("RGB")
 
121
 
122
- return self.transform(image)
123
 
124
 
125
  class EmbeddingMetadataDataset(Dataset):
@@ -270,11 +272,10 @@ class FungiMEEModel(nn.Module):
270
 
271
  class FungiEnsembleModel(nn.Module):
272
 
273
- def __init__(self, models, softmax=True) -> None:
274
  super().__init__()
275
 
276
  self.models = nn.ModuleList()
277
- self.softmax = softmax
278
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
279
 
280
  for model in models:
@@ -291,12 +292,7 @@ class FungiEnsembleModel(nn.Module):
291
  for model in self.models:
292
  logits = model.forward(img_emb, metadata)
293
 
294
- p = (
295
- logits.softmax(dim=1).detach().cpu()
296
- if self.softmax
297
- else logits.detach().cpu()
298
- )
299
-
300
  probs.append(p)
301
 
302
  return torch.stack(probs).mean(dim=0)
@@ -314,25 +310,32 @@ def make_submission(metadata_df):
314
  OUTPUT_CSV_PATH = "./submission.csv"
315
  BASE_CKPT_PATH = "./checkpoints"
316
 
317
- model_names = os.listdir(BASE_CKPT_PATH)
 
318
 
319
- models = []
320
 
321
- for model_path in model_names:
322
- print("loading ", model_path)
323
- ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
324
 
325
- ckpt = torch.load(ckpt_path)
326
- model = FungiMEEModel()
327
- model.load_state_dict(
328
- {w: ckpt["model." + w] for w in model.state_dict().keys()}
329
- )
330
- model.eval()
331
- model.cuda()
332
 
333
- models.append(model)
334
 
335
- ensemble_model = FungiEnsembleModel(models)
 
 
 
 
 
 
336
 
337
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
338
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
@@ -340,7 +343,7 @@ def make_submission(metadata_df):
340
  preds = []
341
  for data in tqdm(loader):
342
  emb, metadata = data
343
- pred = ensemble_model.forward(emb, metadata)
344
  preds.append(pred)
345
 
346
  all_preds = torch.vstack(preds).numpy()
@@ -363,18 +366,21 @@ def make_submission(metadata_df):
363
 
364
  if __name__ == "__main__":
365
 
366
- # # # Real submission
367
- import zipfile
 
 
 
368
 
369
- with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
370
- zip_ref.extractall("/tmp/data")
371
 
372
- metadata_file_path = "./_test_preprocessed.csv"
373
- root_dir = "/tmp/data"
374
 
375
  # Test submission
376
- # metadata_file_path = "../trial_submission.csv"
377
- # root_dir = "../data/DF_FULL"
378
 
379
  ##############
380
 
 
4
  import timm
5
  import torch
6
  import torch.nn as nn
 
 
7
  from PIL import Image
8
  from timm.models.metaformer import MlpHead
9
  from torch.utils.data import DataLoader, Dataset
10
  from tqdm import tqdm
11
+ from albumentations import Compose, Normalize, Resize
12
+ from albumentations.pytorch import ToTensorV2
13
+ import cv2
14
 
15
  DIM = 518
16
  DATE_SIZE = 4
 
100
  class ImageDataset(Dataset):
101
  def __init__(self, df, local_filepath):
102
  self.df = df
103
+ self.transform = Compose(
104
  [
105
+ Resize(DIM, DIM),
106
+ Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
107
+ ToTensorV2(),
108
  ]
109
  )
110
 
 
118
  def __getitem__(self, idx):
119
  image_path = os.path.join(self.local_filepath, self.filepaths[idx])
120
 
121
+ image = cv2.imread(image_path)
122
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
123
 
124
+ return self.transform(image=image)['image']
125
 
126
 
127
  class EmbeddingMetadataDataset(Dataset):
 
272
 
273
  class FungiEnsembleModel(nn.Module):
274
 
275
+ def __init__(self, models) -> None:
276
  super().__init__()
277
 
278
  self.models = nn.ModuleList()
 
279
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
280
 
281
  for model in models:
 
292
  for model in self.models:
293
  logits = model.forward(img_emb, metadata)
294
 
295
+ p = logits.softmax(dim=1).detach().cpu()
 
 
 
 
 
296
  probs.append(p)
297
 
298
  return torch.stack(probs).mean(dim=0)
 
310
  OUTPUT_CSV_PATH = "./submission.csv"
311
  BASE_CKPT_PATH = "./checkpoints"
312
 
313
+ ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
314
+ # model_names = os.listdir(BASE_CKPT_PATH)
315
 
316
+ # models = []
317
 
318
+ # for model_path in model_names:
319
+ # print("loading ", model_path)
320
+ # ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
321
 
322
+ # ckpt = torch.load(ckpt_path)
323
+ # model = FungiMEEModel()
324
+ # model.load_state_dict(
325
+ # {w: ckpt["model." + w] for w in model.state_dict().keys()}
326
+ # )
327
+ # model.eval()
328
+ # model.cuda()
329
 
330
+ # models.append(model)
331
 
332
+ # fungi_model = FungiEnsembleModel(models)
333
+
334
+ fungi_model = FungiMEEModel()
335
+ ckpt = torch.load(ckpt_path)
336
+ fungi_model.load_state_dict(
337
+ {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
338
+ )
339
 
340
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
341
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
 
343
  preds = []
344
  for data in tqdm(loader):
345
  emb, metadata = data
346
+ pred = fungi_model.forward(emb, metadata)
347
  preds.append(pred)
348
 
349
  all_preds = torch.vstack(preds).numpy()
 
366
 
367
  if __name__ == "__main__":
368
 
369
+ MODEL_PATH = "metaformer-s-224.pth"
370
+ MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
371
+
372
+ # # # # Real submission
373
+ # import zipfile
374
 
375
+ # with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
376
+ # zip_ref.extractall("/tmp/data")
377
 
378
+ # metadata_file_path = "./_test_preprocessed.csv"
379
+ # root_dir = "/tmp/data"
380
 
381
  # Test submission
382
+ metadata_file_path = "../trial_submission.csv"
383
+ root_dir = "../data/DF_FULL"
384
 
385
  ##############
386