fixed script
Browse files
script.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import io
|
2 |
import os
|
3 |
from typing import List
|
4 |
|
@@ -10,9 +9,18 @@ import torch
|
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
12 |
import torchvision.transforms as T
|
13 |
-
from albumentations import (
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from albumentations.pytorch import ToTensorV2
|
17 |
from PIL import Image
|
18 |
from timm.layers import LayerNorm2d, SelectAdaptivePool2d
|
@@ -20,14 +28,13 @@ from timm.models.metaformer import MlpHead
|
|
20 |
from torch.utils.data import DataLoader, Dataset
|
21 |
from tqdm import tqdm
|
22 |
|
23 |
-
|
24 |
-
DEFAULT_HEIGHT = 518
|
25 |
|
26 |
def get_transforms(*, data, model=None, width=None, height=None):
|
27 |
assert data in ("train", "valid")
|
28 |
|
29 |
-
width = width if width else
|
30 |
-
height = height if height else
|
31 |
|
32 |
model_mean = list(model.default_cfg["mean"]) if model else (0.5, 0.5, 0.5)
|
33 |
model_std = list(model.default_cfg["std"]) if model else (0.5, 0.5, 0.5)
|
@@ -53,8 +60,6 @@ def get_transforms(*, data, model=None, width=None, height=None):
|
|
53 |
]
|
54 |
)
|
55 |
|
56 |
-
DIM = 518
|
57 |
-
BASE_PATH = "../data/DF_FULL"
|
58 |
|
59 |
def generate_embeddings(metadata_file_path, root_dir):
|
60 |
|
@@ -69,7 +74,9 @@ def generate_embeddings(metadata_file_path, root_dir):
|
|
69 |
loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
|
70 |
|
71 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
72 |
-
model = timm.create_model(
|
|
|
|
|
73 |
model = model.to(device)
|
74 |
model.eval()
|
75 |
|
@@ -87,13 +94,14 @@ def generate_embeddings(metadata_file_path, root_dir):
|
|
87 |
|
88 |
embs_list = [x for x in all_embs]
|
89 |
metadata_df["embedding"] = embs_list
|
90 |
-
|
91 |
return metadata_df
|
92 |
|
93 |
|
94 |
-
TIME = [
|
95 |
-
GEO = [
|
96 |
-
SUBSTRATE = [
|
|
|
97 |
"substrate_1",
|
98 |
"substrate_2",
|
99 |
"substrate_3",
|
@@ -168,11 +176,12 @@ SUBSTRATE = ["substrate_0",
|
|
168 |
"habitat_31",
|
169 |
]
|
170 |
|
|
|
171 |
class EmbeddingMetadataDataset(Dataset):
|
172 |
def __init__(self, df):
|
173 |
self.df = df
|
174 |
|
175 |
-
self.emb = df[
|
176 |
self.metadata_date = df[TIME].to_numpy()
|
177 |
self.metadata_geo = df[GEO].to_numpy()
|
178 |
self.metadata_substrate = df[SUBSTRATE].to_numpy()
|
@@ -186,24 +195,27 @@ class EmbeddingMetadataDataset(Dataset):
|
|
186 |
metadata = {
|
187 |
"date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
|
188 |
"geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
|
189 |
-
"substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
|
|
|
|
|
190 |
}
|
191 |
|
192 |
return embedding, metadata
|
193 |
-
|
194 |
|
195 |
class ImageMetadataDataset(Dataset):
|
196 |
def __init__(self, df, transform=None, local_filepath=None):
|
197 |
self.df = df
|
198 |
self.transform = transform
|
199 |
self.local_filepath = local_filepath
|
200 |
-
|
201 |
-
self.filepaths =
|
|
|
|
|
202 |
self.metadata_date = df[TIME].to_numpy()
|
203 |
self.metadata_geo = df[GEO].to_numpy()
|
204 |
self.metadata_substrate = df[SUBSTRATE].to_numpy()
|
205 |
|
206 |
-
|
207 |
def __len__(self):
|
208 |
return len(self.df)
|
209 |
|
@@ -223,16 +235,20 @@ class ImageMetadataDataset(Dataset):
|
|
223 |
metadata = {
|
224 |
"date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
|
225 |
"geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
|
226 |
-
"substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
|
|
|
|
|
227 |
}
|
228 |
|
229 |
return image, metadata
|
230 |
|
|
|
231 |
DATE_SIZE = 4
|
232 |
GEO_SIZE = 7
|
233 |
SUBSTRATE_SIZE = 73
|
234 |
NUM_CLASSES = 1717
|
235 |
|
|
|
236 |
class StarReLU(nn.Module):
|
237 |
"""
|
238 |
StarReLU: s * relu(x) ** 2 + b
|
@@ -260,6 +276,7 @@ class StarReLU(nn.Module):
|
|
260 |
def forward(self, x):
|
261 |
return self.scale * self.relu(x) ** 2 + self.bias
|
262 |
|
|
|
263 |
class FungiMEEModel(nn.Module):
|
264 |
def __init__(
|
265 |
self,
|
@@ -272,7 +289,6 @@ class FungiMEEModel(nn.Module):
|
|
272 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
273 |
print(f"Using devide: {self.device}")
|
274 |
|
275 |
-
|
276 |
self.date_embedding = MlpHead(
|
277 |
dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
|
278 |
)
|
@@ -286,38 +302,43 @@ class FungiMEEModel(nn.Module):
|
|
286 |
act_layer=StarReLU,
|
287 |
)
|
288 |
|
289 |
-
self.encoder = nn.TransformerEncoder(
|
290 |
-
|
|
|
|
|
|
|
291 |
self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)
|
292 |
|
293 |
for param in self.parameters():
|
294 |
if param.dim() > 1:
|
295 |
nn.init.kaiming_normal_(param)
|
296 |
|
297 |
-
|
298 |
def forward(self, img_emb, metadata):
|
299 |
|
300 |
img_emb = img_emb.to(self.device)
|
301 |
-
|
302 |
date_emb = self.date_embedding.forward(metadata["date"].to(self.device))
|
303 |
geo_emb = self.geo_embedding.forward(metadata["geo"].to(self.device))
|
304 |
substr_emb = self.substr_embedding.forward(metadata["substr"].to(self.device))
|
305 |
|
306 |
-
full_emb = torch.stack(
|
|
|
|
|
307 |
# print(full_emb.shape)
|
308 |
|
309 |
cls_emb = self.encoder.forward(full_emb)[:, 0, :].squeeze(1)
|
310 |
|
311 |
return self.head.forward(cls_emb)
|
312 |
-
|
313 |
def predict(self, img_emb, metadata):
|
314 |
-
|
315 |
logits = self.forward(img_emb, metadata)
|
316 |
|
317 |
# Any preprocess happens here
|
318 |
|
319 |
return logits.argmax(1).tolist()
|
320 |
-
|
|
|
321 |
class FungiEnsembleModel(nn.Module):
|
322 |
|
323 |
def __init__(self, models, softmax=True) -> None:
|
@@ -331,39 +352,46 @@ class FungiEnsembleModel(nn.Module):
|
|
331 |
model = model.to(self.device)
|
332 |
model.eval()
|
333 |
self.models.append(model)
|
334 |
-
|
335 |
def forward(self, img_emb, metadata):
|
336 |
|
337 |
img_emb = img_emb.to(self.device)
|
338 |
|
339 |
-
probs = []
|
340 |
|
341 |
for model in self.models:
|
342 |
logits = model.forward(img_emb, metadata)
|
343 |
-
|
344 |
-
p =
|
|
|
|
|
|
|
|
|
345 |
|
346 |
probs.append(p)
|
347 |
|
348 |
return torch.stack(probs).mean(dim=0)
|
349 |
-
|
350 |
def predict(self, img_emb, metadata):
|
351 |
-
|
352 |
logits = self.forward(img_emb, metadata)
|
353 |
|
354 |
# Any preprocess happens here
|
355 |
|
356 |
return logits.argmax(1).tolist()
|
357 |
-
|
358 |
|
359 |
def is_gpu_available():
|
360 |
"""Check if the python package `onnxruntime-gpu` is installed."""
|
361 |
return torch.cuda.is_available()
|
362 |
|
|
|
363 |
class PytorchWorker:
|
364 |
"""Run inference using ONNX runtime."""
|
365 |
|
366 |
-
def __init__(
|
|
|
|
|
367 |
|
368 |
def _load_model(model_name, model_path):
|
369 |
|
@@ -379,10 +407,13 @@ class PytorchWorker:
|
|
379 |
|
380 |
self.model = _load_model(model_name, model_path)
|
381 |
|
382 |
-
self.transforms = T.Compose(
|
383 |
-
|
384 |
-
|
385 |
-
|
|
|
|
|
|
|
386 |
|
387 |
def predict_image(self, image: np.ndarray):
|
388 |
"""Run inference using ONNX runtime.
|
@@ -397,9 +428,9 @@ class PytorchWorker:
|
|
397 |
|
398 |
|
399 |
def make_submission(metadata_df, model_names=None):
|
400 |
-
|
401 |
-
OUTPUT_CSV_PATH="./submission.csv"
|
402 |
-
|
403 |
"""Make submission with given """
|
404 |
|
405 |
BASE_CKPT_PATH = "./checkpoints"
|
@@ -414,12 +445,14 @@ def make_submission(metadata_df, model_names=None):
|
|
414 |
|
415 |
ckpt = torch.load(ckpt_path)
|
416 |
model = FungiMEEModel()
|
417 |
-
model.load_state_dict(
|
|
|
|
|
418 |
model.eval()
|
419 |
model.cuda()
|
420 |
|
421 |
models.append(model)
|
422 |
-
|
423 |
ensemble_model = FungiEnsembleModel(models)
|
424 |
|
425 |
embedding_dataset = EmbeddingMetadataDataset(metadata_df)
|
@@ -433,31 +466,39 @@ def make_submission(metadata_df, model_names=None):
|
|
433 |
|
434 |
all_preds = torch.vstack(preds).numpy()
|
435 |
|
436 |
-
preds_df = metadata_df[[
|
437 |
-
preds_df[
|
438 |
-
preds_df =
|
439 |
-
|
440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
|
442 |
print("Submission complete")
|
443 |
|
|
|
444 |
if __name__ == "__main__":
|
445 |
|
446 |
MODEL_PATH = "metaformer-s-224.pth"
|
447 |
MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
|
448 |
|
449 |
# # Real submission
|
450 |
-
import zipfile
|
451 |
|
452 |
-
with zipfile.ZipFile("/tmp/data/private_testset.zip",
|
453 |
-
|
454 |
|
455 |
-
metadata_file_path = "./_test_preprocessed.csv"
|
456 |
-
root_dir = "/tmp/data"
|
457 |
|
458 |
# Test submission
|
459 |
-
|
460 |
-
|
461 |
|
462 |
##############
|
463 |
|
|
|
|
|
1 |
import os
|
2 |
from typing import List
|
3 |
|
|
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
11 |
import torchvision.transforms as T
|
12 |
+
from albumentations import (
|
13 |
+
CenterCrop,
|
14 |
+
Compose,
|
15 |
+
HorizontalFlip,
|
16 |
+
Normalize,
|
17 |
+
PadIfNeeded,
|
18 |
+
RandomBrightnessContrast,
|
19 |
+
RandomCrop,
|
20 |
+
RandomResizedCrop,
|
21 |
+
Resize,
|
22 |
+
VerticalFlip,
|
23 |
+
)
|
24 |
from albumentations.pytorch import ToTensorV2
|
25 |
from PIL import Image
|
26 |
from timm.layers import LayerNorm2d, SelectAdaptivePool2d
|
|
|
28 |
from torch.utils.data import DataLoader, Dataset
|
29 |
from tqdm import tqdm
|
30 |
|
31 |
+
DIM = 518
|
|
|
32 |
|
33 |
def get_transforms(*, data, model=None, width=None, height=None):
|
34 |
assert data in ("train", "valid")
|
35 |
|
36 |
+
width = width if width else DIM
|
37 |
+
height = height if height else DIM
|
38 |
|
39 |
model_mean = list(model.default_cfg["mean"]) if model else (0.5, 0.5, 0.5)
|
40 |
model_std = list(model.default_cfg["std"]) if model else (0.5, 0.5, 0.5)
|
|
|
60 |
]
|
61 |
)
|
62 |
|
|
|
|
|
63 |
|
64 |
def generate_embeddings(metadata_file_path, root_dir):
|
65 |
|
|
|
74 |
loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
|
75 |
|
76 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
77 |
+
model = timm.create_model(
|
78 |
+
"timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
|
79 |
+
)
|
80 |
model = model.to(device)
|
81 |
model.eval()
|
82 |
|
|
|
94 |
|
95 |
embs_list = [x for x in all_embs]
|
96 |
metadata_df["embedding"] = embs_list
|
97 |
+
|
98 |
return metadata_df
|
99 |
|
100 |
|
101 |
+
TIME = ["m0", "m1", "d0", "d1"]
|
102 |
+
GEO = ["g0", "g1", "g2", "g3", "g4", "g5", "g_float"]
|
103 |
+
SUBSTRATE = [
|
104 |
+
"substrate_0",
|
105 |
"substrate_1",
|
106 |
"substrate_2",
|
107 |
"substrate_3",
|
|
|
176 |
"habitat_31",
|
177 |
]
|
178 |
|
179 |
+
|
180 |
class EmbeddingMetadataDataset(Dataset):
|
181 |
def __init__(self, df):
|
182 |
self.df = df
|
183 |
|
184 |
+
self.emb = df["embedding"]
|
185 |
self.metadata_date = df[TIME].to_numpy()
|
186 |
self.metadata_geo = df[GEO].to_numpy()
|
187 |
self.metadata_substrate = df[SUBSTRATE].to_numpy()
|
|
|
195 |
metadata = {
|
196 |
"date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
|
197 |
"geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
|
198 |
+
"substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
|
199 |
+
torch.float
|
200 |
+
),
|
201 |
}
|
202 |
|
203 |
return embedding, metadata
|
204 |
+
|
205 |
|
206 |
class ImageMetadataDataset(Dataset):
|
207 |
def __init__(self, df, transform=None, local_filepath=None):
    """Keep references to the inputs and cache per-row lookup data from ``df``.

    Args:
        df: DataFrame that provides an ``image_path`` column together with
            the columns named in ``TIME``, ``GEO`` and ``SUBSTRATE``.
        transform: Stored on the instance; not used inside this method.
        local_filepath: Stored on the instance; not used inside this method.
    """
    self.df, self.transform, self.local_filepath = df, transform, local_filepath

    self.filepaths = df["image_path"].to_list()

    # Convert each metadata column group to a NumPy array once, up front.
    date_values, geo_values, substrate_values = (
        df[columns].to_numpy() for columns in (TIME, GEO, SUBSTRATE)
    )
    self.metadata_date = date_values
    self.metadata_geo = geo_values
    self.metadata_substrate = substrate_values
|
218 |
|
|
|
219 |
def __len__(self):
    """Return the number of items, i.e. ``len`` of the wrapped ``self.df``."""
    row_count = len(self.df)
    return row_count
|
221 |
|
|
|
235 |
metadata = {
|
236 |
"date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
|
237 |
"geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
|
238 |
+
"substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
|
239 |
+
torch.float
|
240 |
+
),
|
241 |
}
|
242 |
|
243 |
return image, metadata
|
244 |
|
245 |
+
|
246 |
DATE_SIZE = 4
|
247 |
GEO_SIZE = 7
|
248 |
SUBSTRATE_SIZE = 73
|
249 |
NUM_CLASSES = 1717
|
250 |
|
251 |
+
|
252 |
class StarReLU(nn.Module):
|
253 |
"""
|
254 |
StarReLU: s * relu(x) ** 2 + b
|
|
|
276 |
def forward(self, x):
    """Compute ``scale * relu(x) ** 2 + bias`` for the input ``x``."""
    squared_activation = self.relu(x) ** 2
    return self.scale * squared_activation + self.bias
|
278 |
|
279 |
+
|
280 |
class FungiMEEModel(nn.Module):
|
281 |
def __init__(
|
282 |
self,
|
|
|
289 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
290 |
print(f"Using devide: {self.device}")
|
291 |
|
|
|
292 |
self.date_embedding = MlpHead(
|
293 |
dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
|
294 |
)
|
|
|
302 |
act_layer=StarReLU,
|
303 |
)
|
304 |
|
305 |
+
self.encoder = nn.TransformerEncoder(
|
306 |
+
nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
|
307 |
+
num_layers=4,
|
308 |
+
)
|
309 |
+
|
310 |
self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)
|
311 |
|
312 |
for param in self.parameters():
|
313 |
if param.dim() > 1:
|
314 |
nn.init.kaiming_normal_(param)
|
315 |
|
|
|
316 |
def forward(self, img_emb, metadata):
    """Fuse an image embedding with the metadata embeddings and classify.

    Args:
        img_emb: Image embedding tensor. It is stacked with the three
            metadata embeddings, so it must have the same shape as each of
            them (presumably ``(batch, dim)`` -- confirm against the caller).
        metadata: Mapping with tensor entries under the keys ``"date"``,
            ``"geo"`` and ``"substr"``.

    Returns:
        The output of ``self.head`` applied to the encoder output at
        position 0 along dim 1, i.e. the slot that held ``img_emb``.
    """
    img_emb = img_emb.to(self.device)

    # Call the sub-modules themselves instead of their ``.forward`` methods so
    # that ``nn.Module.__call__`` runs and any registered hooks fire.
    date_emb = self.date_embedding(metadata["date"].to(self.device))
    geo_emb = self.geo_embedding(metadata["geo"].to(self.device))
    substr_emb = self.substr_embedding(metadata["substr"].to(self.device))

    # A new dim 1 is inserted, giving four entries per sample:
    # image, date, geo, substrate (in that order).
    full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)

    # Integer indexing already removes dim 1, so no extra squeeze is needed
    # (a squeeze(1) here would drop the feature axis whenever it has size 1).
    cls_emb = self.encoder(full_emb)[:, 0, :]

    return self.head(cls_emb)
|
332 |
+
|
333 |
def predict(self, img_emb, metadata):
    """Run ``forward`` and return the arg-max index along dim 1 as a list.

    Args:
        img_emb: Passed through to ``forward`` unchanged.
        metadata: Passed through to ``forward`` unchanged.

    Returns:
        A Python list with one predicted index per row of the logits.
    """
    scores = self.forward(img_emb, metadata)

    # Any post-processing of the raw scores would be inserted here.
    top_index = scores.argmax(1)
    return top_index.tolist()
|
340 |
+
|
341 |
+
|
342 |
class FungiEnsembleModel(nn.Module):
|
343 |
|
344 |
def __init__(self, models, softmax=True) -> None:
|
|
|
352 |
model = model.to(self.device)
|
353 |
model.eval()
|
354 |
self.models.append(model)
|
355 |
+
|
356 |
def forward(self, img_emb, metadata):
    """Average the outputs of every model in ``self.models``.

    Each member model is run on the same ``img_emb`` and ``metadata``. When
    ``self.softmax`` is truthy its logits are converted with a softmax over
    dim 1 first; otherwise the raw logits are used. Every per-model result is
    detached and moved to the CPU before averaging.

    Returns:
        The element-wise mean over models of the per-model outputs.
    """
    img_emb = img_emb.to(self.device)

    def _member_output(member):
        # Output of a single ensemble member, detached and on the CPU.
        raw = member.forward(img_emb, metadata)
        if self.softmax:
            raw = raw.softmax(dim=1)
        return raw.detach().cpu()

    per_model = [_member_output(member) for member in self.models]
    return torch.stack(per_model).mean(dim=0)
|
374 |
+
|
375 |
def predict(self, img_emb, metadata):
    """Return the arg-max index along dim 1 of the ensemble's averaged output.

    Args:
        img_emb: Forwarded to ``forward`` as-is.
        metadata: Forwarded to ``forward`` as-is.

    Returns:
        A Python list holding one index per row of the averaged output.
    """
    averaged = self.forward(img_emb, metadata)

    # Any post-processing of the averaged scores would be inserted here.
    winners = averaged.argmax(1)
    return winners.tolist()
|
382 |
+
|
383 |
|
384 |
def is_gpu_available():
    """Return whether PyTorch reports CUDA as available.

    This is a thin wrapper around ``torch.cuda.is_available()``; it does not
    inspect which Python packages are installed.
    """
    return torch.cuda.is_available()
|
387 |
|
388 |
+
|
389 |
class PytorchWorker:
|
390 |
"""Run inference using ONNX runtime."""
|
391 |
|
392 |
+
def __init__(
|
393 |
+
self, model_path: str, model_name: str, number_of_categories: int = 1605
|
394 |
+
):
|
395 |
|
396 |
def _load_model(model_name, model_path):
|
397 |
|
|
|
407 |
|
408 |
self.model = _load_model(model_name, model_path)
|
409 |
|
410 |
+
self.transforms = T.Compose(
|
411 |
+
[
|
412 |
+
T.Resize((518, 518)),
|
413 |
+
T.ToTensor(),
|
414 |
+
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
415 |
+
]
|
416 |
+
)
|
417 |
|
418 |
def predict_image(self, image: np.ndarray):
|
419 |
"""Run inference using ONNX runtime.
|
|
|
428 |
|
429 |
|
430 |
def make_submission(metadata_df, model_names=None):
|
431 |
+
|
432 |
+
OUTPUT_CSV_PATH = "./submission.csv"
|
433 |
+
|
434 |
"""Make submission with given """
|
435 |
|
436 |
BASE_CKPT_PATH = "./checkpoints"
|
|
|
445 |
|
446 |
ckpt = torch.load(ckpt_path)
|
447 |
model = FungiMEEModel()
|
448 |
+
model.load_state_dict(
|
449 |
+
{w: ckpt["state_dict"]["model." + w] for w in model.state_dict().keys()}
|
450 |
+
)
|
451 |
model.eval()
|
452 |
model.cuda()
|
453 |
|
454 |
models.append(model)
|
455 |
+
|
456 |
ensemble_model = FungiEnsembleModel(models)
|
457 |
|
458 |
embedding_dataset = EmbeddingMetadataDataset(metadata_df)
|
|
|
466 |
|
467 |
all_preds = torch.vstack(preds).numpy()
|
468 |
|
469 |
+
preds_df = metadata_df[["observation_id", "image_path"]]
|
470 |
+
preds_df["preds"] = [i for i in all_preds]
|
471 |
+
preds_df = (
|
472 |
+
preds_df[["observation_id", "preds"]]
|
473 |
+
.groupby("observation_id")
|
474 |
+
.mean()
|
475 |
+
.reset_index()
|
476 |
+
)
|
477 |
+
preds_df["class_id"] = preds_df["preds"].apply(
|
478 |
+
lambda x: x.argmax() if x.argmax() <= 1603 else -1
|
479 |
+
)
|
480 |
+
preds_df[["observation_id", "class_id"]].to_csv(OUTPUT_CSV_PATH, index=None)
|
481 |
|
482 |
print("Submission complete")
|
483 |
|
484 |
+
|
485 |
if __name__ == "__main__":
|
486 |
|
487 |
MODEL_PATH = "metaformer-s-224.pth"
|
488 |
MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
|
489 |
|
490 |
# # Real submission
|
491 |
+
# import zipfile
|
492 |
|
493 |
+
# with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
|
494 |
+
# zip_ref.extractall("/tmp/data")
|
495 |
|
496 |
+
# metadata_file_path = "./_test_preprocessed.csv"
|
497 |
+
# root_dir = "/tmp/data"
|
498 |
|
499 |
# Test submission
|
500 |
+
metadata_file_path = "../trial_submission.csv"
|
501 |
+
root_dir = "../data/DF_FULL"
|
502 |
|
503 |
##############
|
504 |
|