fixed ckpts
- checkpoints/dino_2_optuna_05242055.ckpt +2 -2
- checkpoints/dino_2_optuna_05242156.ckpt +2 -2
- checkpoints/dino_2_optuna_05242231.ckpt +2 -2
- checkpoints/dino_2_optuna_05242344.ckpt +2 -2
- checkpoints/dino_optuna_05241222.ckpt +2 -2
- checkpoints/dino_optuna_05241257.ckpt +2 -2
- checkpoints/dino_optuna_05241449.ckpt +2 -2
- script.py +64 -179
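
Each .ckpt in this repo is a Git LFS pointer file, not the binary itself: three text lines (spec version, `oid sha256:<hash>`, `size <bytes>`). The broken pointers had empty oid and size fields, which makes the blobs unresolvable for LFS clients; this commit restores the real hash and byte count for all seven checkpoints and simplifies script.py accordingly.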
checkpoints/dino_2_optuna_05242055.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:bd95a08e0e7a725425d91810db61fc3a1167abe59ddf7ceedd067304dfc8b097
+size 187793106

checkpoints/dino_2_optuna_05242156.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:403a4f866df7f3bc5358c04a4f8e188e49cab90d84e7812a9261abfc9fab3bc3
+size 187793106

checkpoints/dino_2_optuna_05242231.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1bca4793b9032935502276fb4e2c045a661cd7b158e2633f0f27bc5d017e4141
+size 187793106

checkpoints/dino_2_optuna_05242344.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:029202d06bad69443016b4b61149a7725dae9d1ba6faac04089c63ac01040bf7
+size 187793106

checkpoints/dino_optuna_05241222.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5585d99338a3911e9c133e829c6a1218c9f18fc6cb7daeffbfc9fdf669c92c86
+size 187792874

checkpoints/dino_optuna_05241257.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:885543336dd814a2bb598d96ca99a01564e4d7eca69778ceabe307ecb14a6d89
+size 187792874

checkpoints/dino_optuna_05241449.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ca08ea13541b791d1d46f567727cccc4b1e9824064694803bf92473fc79f4192
+size 187792874
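
A quick way to sanity-check pointers like these after the fix is to recompute the blob's SHA-256 and size and compare them against the pointer's oid and size lines. This is a minimal sketch, not part of the repo; the `verify_lfs_pointer` helper and the example paths are hypothetical.

import hashlib
import os

def verify_lfs_pointer(pointer_path, blob_path):
    # Parse the three pointer lines: version, "oid sha256:<hash>", "size <bytes>".
    with open(pointer_path) as f:
        lines = f.read().splitlines()
    expected_oid = lines[1].split("oid sha256:")[1].strip()
    expected_size = int(lines[2].split("size")[1].strip())

    # Hash the binary in 1 MiB chunks so large checkpoints fit in memory.
    sha = hashlib.sha256()
    with open(blob_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)

    return (sha.hexdigest() == expected_oid
            and os.path.getsize(blob_path) == expected_size)

# e.g. verify_lfs_pointer("pointer.txt", "checkpoints/dino_2_optuna_05242055.ckpt")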
script.py
CHANGED
@@ -1,7 +1,4 @@
 import os
-from typing import List
-
-import cv2
 import numpy as np
 import pandas as pd
 import timm
@@ -9,94 +6,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as T
-from albumentations import (
-    CenterCrop,
-    Compose,
-    HorizontalFlip,
-    Normalize,
-    PadIfNeeded,
-    RandomBrightnessContrast,
-    RandomCrop,
-    RandomResizedCrop,
-    Resize,
-    VerticalFlip,
-)
-from albumentations.pytorch import ToTensorV2
 from PIL import Image
-from timm.layers import LayerNorm2d, SelectAdaptivePool2d
 from timm.models.metaformer import MlpHead
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 
 DIM = 518
-
-
-def get_transforms(data, model=None, width=None, height=None):
-
-    width = width if width else DIM
-    height = height if height else DIM
-
-    model_mean = list(model.default_cfg["mean"]) if model else (0.5, 0.5, 0.5)
-    model_std = list(model.default_cfg["std"]) if model else (0.5, 0.5, 0.5)
-
-    if data == "train":
-        return Compose(
-            [
-                RandomResizedCrop(width, height, scale=(0.6, 1.0)),
-                HorizontalFlip(p=0.5),
-                VerticalFlip(p=0.5),
-                RandomBrightnessContrast(p=0.2),
-                Normalize(mean=model_mean, std=model_std),
-                ToTensorV2(),
-            ]
-        )
-
-    elif data == "valid":
-        return Compose(
-            [
-                Resize(width, height),
-                Normalize(mean=model_mean, std=model_std),
-                ToTensorV2(),
-            ]
-        )
-
-
-def generate_embeddings(metadata_file_path, root_dir):
-
-    metadata_df = pd.read_csv(metadata_file_path)
-
-    transforms = get_transforms(data="valid", width=DIM, height=DIM)
-
-    test_dataset = ImageMetadataDataset(
-        metadata_df, local_filepath=root_dir, transform=transforms
-    )
-
-    loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = timm.create_model(
-        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
-    )
-    model = model.to(device)
-    model.eval()
-
-    all_embs = []
-    for data in tqdm(loader):
-
-        img, _ = data
-        img = img.to(device)
-
-        emb = model.forward(img)
-
-        all_embs.append(emb.detach().cpu().numpy())
-
-    all_embs = np.vstack(all_embs)
-
-    embs_list = [x for x in all_embs]
-    metadata_df["embedding"] = embs_list
-
-    return metadata_df
-
+DATE_SIZE = 4
+GEO_SIZE = 7
+SUBSTRATE_SIZE = 73
+NUM_CLASSES = 1717
 
 TIME = ["m0", "m1", "d0", "d1"]
 GEO = ["g0", "g1", "g2", "g3", "g4", "g5", "g_float"]
@@ -177,6 +96,32 @@ SUBSTRATE = [
 ]
 
 
+class ImageDataset(Dataset):
+    def __init__(self, df, local_filepath):
+        self.df = df
+        self.transform = T.Compose(
+            [
+                T.Resize((DIM, DIM)),
+                T.ToTensor(),
+                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+            ]
+        )
+
+        self.local_filepath = local_filepath
+
+        self.filepaths = df["image_path"].to_list()
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        image_path = os.path.join(self.local_filepath, self.filepaths[idx])
+
+        image = Image.open(image_path).convert("RGB")
+
+        return self.transform(image)
+
+
 class EmbeddingMetadataDataset(Dataset):
     def __init__(self, df):
         self.df = df
@@ -203,50 +148,37 @@ class EmbeddingMetadataDataset(Dataset):
         return embedding, metadata
 
 
-class ImageMetadataDataset(Dataset):
-    def __init__(self, df, transform=None, local_filepath=None):
-        self.df = df
-        self.transform = transform
-        self.local_filepath = local_filepath
+def generate_embeddings(metadata_file_path, root_dir):
 
-        self.filepaths = (
-            df["image_path"].to_list()
-        )
-        self.metadata_date = df[TIME].to_numpy()
-        self.metadata_geo = df[GEO].to_numpy()
-        self.metadata_substrate = df[SUBSTRATE].to_numpy()
+    metadata_df = pd.read_csv(metadata_file_path)
 
-    def __len__(self):
-        return len(self.df)
+    test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
 
-    def __getitem__(self, idx):
-        file_path = os.path.join(self.local_filepath, self.filepaths[idx])
+    loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=4)
 
-
-
-
-
-        print(file_path)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model = timm.create_model(
+        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
+    )
 
-
-
-        image = augmented["image"]
+    model = model.to(device)
+    model.eval()
 
-
-
-            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
-            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
-                torch.float
-            ),
-        }
+    all_embs = []
+    for img in tqdm(loader):
 
-
+        img = img.to(device)
 
+        emb = model.forward(img)
 
-
-
-
-
+        all_embs.append(emb.detach().cpu().numpy())
+
+    all_embs = np.vstack(all_embs)
+
+    embs_list = [x for x in all_embs]
+    metadata_df["embedding"] = embs_list
+
+    return metadata_df
 
 
 class StarReLU(nn.Module):
@@ -323,8 +255,7 @@ class FungiMEEModel(nn.Module):
 
         full_emb = torch.stack(
             (img_emb, date_emb, geo_emb, substr_emb), dim=1
-        )
-        # print(full_emb.shape)
+        )
 
         cls_emb = self.encoder.forward(full_emb)[:, 0, :].squeeze(1)
 
@@ -334,8 +265,6 @@
 
         logits = self.forward(img_emb, metadata)
 
-        # Any preprocess happens here
-
         return logits.argmax(1).tolist()
 
 
@@ -386,56 +315,12 @@ def is_gpu_available():
     return torch.cuda.is_available()
 
 
-
-    """Run inference using ONNX runtime."""
-
-    def __init__(
-        self, model_path: str, model_name: str, number_of_categories: int = 1605
-    ):
-
-        def _load_model(model_name, model_path):
-
-            print("Setting up Pytorch Model")
-            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-            print(f"Using devide: {self.device}")
-
-            model = timm.create_model(model_name, num_classes=0, pretrained=False)
-            # weights = torch.load(model_path, map_location=self.device)
-            # model.load_state_dict({w.replace("model.", ""): v for w, v in weights.items()})
-
-            return model.to(self.device).eval()
-
-        self.model = _load_model(model_name, model_path)
-
-        self.transforms = T.Compose(
-            [
-                T.Resize((518, 518)),
-                T.ToTensor(),
-                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
-            ]
-        )
-
-    def predict_image(self, image: np.ndarray):
-        """Run inference using ONNX runtime.
-
-        :param image: Input image as numpy array.
-        :return: A list with logits and confidences.
-        """
-
-        self.model(self.transforms(image).unsqueeze(0).to(self.device))
-
-        return [-1]
-
-
-def make_submission(metadata_df, model_names=None):
+def make_submission(metadata_df):
 
     OUTPUT_CSV_PATH = "./submission.csv"
-
-    """Make submission with given """
-
     BASE_CKPT_PATH = "./checkpoints"
 
-    model_names =
+    model_names = os.listdir(BASE_CKPT_PATH)
 
     models = []
 
@@ -446,7 +331,7 @@ def make_submission(metadata_df, model_names=None):
         ckpt = torch.load(ckpt_path)
         model = FungiMEEModel()
        model.load_state_dict(
-            {w: ckpt["
+            {w: ckpt["model." + w] for w in model.state_dict().keys()}
         )
         model.eval()
         model.cuda()
@@ -487,18 +372,18 @@ if __name__ == "__main__":
    MODEL_PATH = "metaformer-s-224.pth"
    MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
 
-    # # Real submission
-    import zipfile
+    # # # Real submission
+    # import zipfile
 
-    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
-        zip_ref.extractall("/tmp/data")
+    # with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
+    #     zip_ref.extractall("/tmp/data")
 
-    metadata_file_path = "./_test_preprocessed.csv"
-    root_dir = "/tmp/data"
+    # metadata_file_path = "./_test_preprocessed.csv"
+    # root_dir = "/tmp/data"
 
     # Test submission
-
-
+    metadata_file_path = "../trial_submission.csv"
+    root_dir = "../data/DF_FULL"
 
     ##############
 
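
The load fix at new line 334 is the crux of make_submission: the saved .ckpt files apparently store every weight under a "model." prefix (Lightning-style), while FungiMEEModel.load_state_dict expects bare parameter names, so each bare key is looked up with the prefix restored. A minimal sketch of that remap with toy tensors (the parameter names below are illustrative, not from the repo):

import torch

# Checkpoint as saved: parameter names carry a "model." prefix.
ckpt = {"model.head.weight": torch.zeros(3, 4), "model.head.bias": torch.zeros(3)}

# Names the bare module expects, without the prefix.
wanted = ["head.weight", "head.bias"]

# Same dict comprehension as the fixed script: restore the prefix on lookup.
state_dict = {w: ckpt["model." + w] for w in wanted}
print(sorted(state_dict))  # ['head.bias', 'head.weight']

With the pointers repaired, torch.load on each file under ./checkpoints yields such a dict, and the comprehension maps it onto the model in one pass.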