added dinov2 weights
- checkpoints/dinov2.bin +3 -0
- script.py +45 -39
checkpoints/dinov2.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78971dc00a0c488f2b2dff17d6dcb7ebe787af70a703d8212b38fc6a33dbcdd4
+size 1217608166
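The weights file is tracked with Git LFS, so the commit itself adds only this three-line pointer: the spec version, the SHA-256 object ID of the blob, and its size in bytes (about 1.2 GB). The tensor data lives in LFS storage. A minimal sketch of verifying a pulled copy against the pointer's OID (the local path is an assumption):

import hashlib

EXPECTED_OID = "78971dc00a0c488f2b2dff17d6dcb7ebe787af70a703d8212b38fc6a33dbcdd4"

sha256 = hashlib.sha256()
with open("checkpoints/dinov2.bin", "rb") as f:  # assumed local path after `git lfs pull`
    # Stream in 1 MiB chunks so the ~1.2 GB file is never held in memory at once.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

assert sha256.hexdigest() == EXPECTED_OID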
script.py
CHANGED
@@ -4,12 +4,13 @@ import pandas as pd
 import timm
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.transforms as T
 from PIL import Image
 from timm.models.metaformer import MlpHead
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
+from albumentations import Compose, Normalize, Resize
+from albumentations.pytorch import ToTensorV2
+import cv2
 
 DIM = 518
 DATE_SIZE = 4
@@ -99,11 +100,11 @@ SUBSTRATE = [
 class ImageDataset(Dataset):
     def __init__(self, df, local_filepath):
         self.df = df
-        self.transform = T.Compose(
+        self.transform = Compose(
             [
-                T.Resize((DIM, DIM)),
-                T.ToTensor(),
-                T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+                Resize(DIM, DIM),
+                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+                ToTensorV2(),
             ]
         )
 
@@ -117,9 +118,10 @@ class ImageDataset(Dataset):
     def __getitem__(self, idx):
         image_path = os.path.join(self.local_filepath, self.filepaths[idx])
 
-        image = Image.open(image_path)
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-        return self.transform(image)
+        return self.transform(image=image)['image']
 
 
 class EmbeddingMetadataDataset(Dataset):
@@ -270,11 +272,10 @@ class FungiMEEModel(nn.Module):
 
 class FungiEnsembleModel(nn.Module):
 
-    def __init__(self, models, softmax=True) -> None:
+    def __init__(self, models) -> None:
         super().__init__()
 
         self.models = nn.ModuleList()
-        self.softmax = softmax
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
         for model in models:
@@ -291,12 +292,7 @@ class FungiEnsembleModel(nn.Module):
         for model in self.models:
             logits = model.forward(img_emb, metadata)
 
-            p = (
-                logits.softmax(dim=1).detach().cpu()
-                if self.softmax
-                else logits.detach().cpu()
-            )
-
+            p = logits.softmax(dim=1).detach().cpu()
             probs.append(p)
 
         return torch.stack(probs).mean(dim=0)
@@ -314,25 +310,32 @@ def make_submission(metadata_df):
     OUTPUT_CSV_PATH = "./submission.csv"
     BASE_CKPT_PATH = "./checkpoints"
 
-    model_names = os.listdir(BASE_CKPT_PATH)
+    ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
+    # model_names = os.listdir(BASE_CKPT_PATH)
 
-    models = []
+    # models = []
 
-    for model_path in model_names:
-        print("loading ", model_path)
-        ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
+    # for model_path in model_names:
+    #     print("loading ", model_path)
+    #     ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
 
-        ckpt = torch.load(ckpt_path)
-        model = FungiMEEModel()
-        model.load_state_dict(
-            {w: ckpt["model." + w] for w in model.state_dict().keys()}
-        )
-        model.eval()
-        model.cuda()
+    #     ckpt = torch.load(ckpt_path)
+    #     model = FungiMEEModel()
+    #     model.load_state_dict(
+    #         {w: ckpt["model." + w] for w in model.state_dict().keys()}
+    #     )
+    #     model.eval()
+    #     model.cuda()
 
-        models.append(model)
+    #     models.append(model)
 
-    fungi_model = FungiEnsembleModel(models)
+    # fungi_model = FungiEnsembleModel(models)
+
+    fungi_model = FungiMEEModel()
+    ckpt = torch.load(ckpt_path)
+    fungi_model.load_state_dict(
+        {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
+    )
 
     embedding_dataset = EmbeddingMetadataDataset(metadata_df)
     loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
@@ -340,7 +343,7 @@ def make_submission(metadata_df):
     preds = []
     for data in tqdm(loader):
         emb, metadata = data
-        pred = fungi_model(emb, metadata)
+        pred = fungi_model.forward(emb, metadata)
        preds.append(pred)
 
     all_preds = torch.vstack(preds).numpy()
@@ -363,18 +366,21 @@ def make_submission(metadata_df):
 
 if __name__ == "__main__":
 
-    # # # Real submission
-    import zipfile
+    MODEL_PATH = "metaformer-s-224.pth"
+    MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
+
+    # # # # Real submission
+    # import zipfile
 
-    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
-        zip_ref.extractall("/tmp/data")
+    # with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
+    #     zip_ref.extractall("/tmp/data")
 
-    metadata_file_path = "./_test_preprocessed.csv"
-    root_dir = "/tmp/data"
+    # metadata_file_path = "./_test_preprocessed.csv"
+    # root_dir = "/tmp/data"
 
     # Test submission
-    # metadata_file_path = "../trial_submission.csv"
-    # root_dir = "../data/DF_FULL"
+    metadata_file_path = "../trial_submission.csv"
+    root_dir = "../data/DF_FULL"
 
     ##############
 
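The script change swaps the PIL/torchvision loading path for OpenCV plus albumentations. Two API details drive the new code: albumentations transforms take a numpy HWC array through the image= keyword and return a dict, and cv2.imread yields BGR channel order, hence the explicit cvtColor. A standalone sketch of the new pipeline (the sample path is hypothetical):

import cv2
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2

DIM = 518  # input resolution used in script.py

transform = Compose(
    [
        Resize(DIM, DIM),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ]
)

image = cv2.imread("sample.jpg")                # hypothetical path; BGR, HWC, uint8
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # albumentations expects RGB
tensor = transform(image=image)["image"]        # float32 tensor, shape (3, DIM, DIM)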
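The submission path now loads a single FungiMEEModel from dino_2_optuna_05242055.ckpt instead of ensembling every checkpoint under ./checkpoints. The dict comprehension strips a "model." prefix from the checkpoint keys so they line up with the bare module's state_dict, which suggests the checkpoint stores the network under a .model attribute of a wrapper. A minimal sketch of the same remapping, assuming torch.load returns that flat, prefixed state dict:

import torch

ckpt = torch.load("./checkpoints/dino_2_optuna_05242055.ckpt", map_location="cpu")

model = FungiMEEModel()  # defined in script.py
model.load_state_dict(
    {w: ckpt["model." + w] for w in model.state_dict().keys()}
)
model.eval()  # inference mode for submission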