chychiu committed
Commit 8f1fb11 · Parent(s): a5f776c

fixed script

Files changed (1): script.py +101 -60
script.py CHANGED
@@ -1,4 +1,3 @@
-import io
 import os
 from typing import List
 
@@ -10,9 +9,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as T
-from albumentations import (CenterCrop, Compose, HorizontalFlip, Normalize,
-                            PadIfNeeded, RandomBrightnessContrast, RandomCrop,
-                            RandomResizedCrop, Resize, VerticalFlip)
+from albumentations import (
+    CenterCrop,
+    Compose,
+    HorizontalFlip,
+    Normalize,
+    PadIfNeeded,
+    RandomBrightnessContrast,
+    RandomCrop,
+    RandomResizedCrop,
+    Resize,
+    VerticalFlip,
+)
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
 from timm.layers import LayerNorm2d, SelectAdaptivePool2d
@@ -20,14 +28,13 @@ from timm.models.metaformer import MlpHead
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 
-DEFAULT_WIDTH = 518
-DEFAULT_HEIGHT = 518
+DIM = 518
 
 def get_transforms(*, data, model=None, width=None, height=None):
     assert data in ("train", "valid")
 
-    width = width if width else DEFAULT_WIDTH
-    height = height if height else DEFAULT_HEIGHT
+    width = width if width else DIM
+    height = height if height else DIM
 
     model_mean = list(model.default_cfg["mean"]) if model else (0.5, 0.5, 0.5)
     model_std = list(model.default_cfg["std"]) if model else (0.5, 0.5, 0.5)
@@ -53,8 +60,6 @@ def get_transforms(*, data, model=None, width=None, height=None):
         ]
     )
 
-DIM = 518
-BASE_PATH = "../data/DF_FULL"
 
 def generate_embeddings(metadata_file_path, root_dir):
 
@@ -69,7 +74,9 @@ def generate_embeddings(metadata_file_path, root_dir):
     loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = timm.create_model("timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True)
+    model = timm.create_model(
+        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
+    )
     model = model.to(device)
     model.eval()
 
@@ -87,13 +94,14 @@ def generate_embeddings(metadata_file_path, root_dir):
 
     embs_list = [x for x in all_embs]
     metadata_df["embedding"] = embs_list
-
+
     return metadata_df
 
 
-TIME = ['m0', 'm1', 'd0', 'd1']
-GEO = ['g0', 'g1', 'g2', 'g3', 'g4', 'g5', 'g_float']
-SUBSTRATE = ["substrate_0",
+TIME = ["m0", "m1", "d0", "d1"]
+GEO = ["g0", "g1", "g2", "g3", "g4", "g5", "g_float"]
+SUBSTRATE = [
+    "substrate_0",
     "substrate_1",
    "substrate_2",
    "substrate_3",
@@ -168,11 +176,12 @@ SUBSTRATE = ["substrate_0",
     "habitat_31",
 ]
 
+
 class EmbeddingMetadataDataset(Dataset):
     def __init__(self, df):
         self.df = df
 
-        self.emb = df['embedding']
+        self.emb = df["embedding"]
         self.metadata_date = df[TIME].to_numpy()
         self.metadata_geo = df[GEO].to_numpy()
         self.metadata_substrate = df[SUBSTRATE].to_numpy()
@@ -186,24 +195,27 @@ class EmbeddingMetadataDataset(Dataset):
         metadata = {
             "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
             "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
-            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(torch.float),
+            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
+                torch.float
+            ),
         }
 
         return embedding, metadata
-
+
 
 class ImageMetadataDataset(Dataset):
     def __init__(self, df, transform=None, local_filepath=None):
         self.df = df
         self.transform = transform
         self.local_filepath = local_filepath
-
-        self.filepaths = df["image_path"].apply(lambda x: x.replace("jpg", "JPG")).to_list()
+
+        self.filepaths = (
+            df["image_path"].to_list()
+        )
         self.metadata_date = df[TIME].to_numpy()
         self.metadata_geo = df[GEO].to_numpy()
         self.metadata_substrate = df[SUBSTRATE].to_numpy()
 
-
     def __len__(self):
         return len(self.df)
 
@@ -223,16 +235,20 @@ class ImageMetadataDataset(Dataset):
         metadata = {
             "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
             "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
-            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(torch.float),
+            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
+                torch.float
+            ),
         }
 
         return image, metadata
 
+
 DATE_SIZE = 4
 GEO_SIZE = 7
 SUBSTRATE_SIZE = 73
 NUM_CLASSES = 1717
 
+
 class StarReLU(nn.Module):
     """
     StarReLU: s * relu(x) ** 2 + b
@@ -260,6 +276,7 @@ class StarReLU(nn.Module):
     def forward(self, x):
         return self.scale * self.relu(x) ** 2 + self.bias
 
+
 class FungiMEEModel(nn.Module):
     def __init__(
         self,
@@ -272,7 +289,6 @@ class FungiMEEModel(nn.Module):
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")
 
-
         self.date_embedding = MlpHead(
             dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
         )
@@ -286,38 +302,43 @@ class FungiMEEModel(nn.Module):
             act_layer=StarReLU,
         )
 
-        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True), num_layers=4)
-
+        self.encoder = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
+            num_layers=4,
+        )
+
         self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)
 
         for param in self.parameters():
             if param.dim() > 1:
                 nn.init.kaiming_normal_(param)
 
-
     def forward(self, img_emb, metadata):
 
         img_emb = img_emb.to(self.device)
-
+
         date_emb = self.date_embedding.forward(metadata["date"].to(self.device))
         geo_emb = self.geo_embedding.forward(metadata["geo"].to(self.device))
         substr_emb = self.substr_embedding.forward(metadata["substr"].to(self.device))
 
-        full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1) #.unsqueeze(0)
+        full_emb = torch.stack(
+            (img_emb, date_emb, geo_emb, substr_emb), dim=1
+        )  # .unsqueeze(0)
         # print(full_emb.shape)
 
         cls_emb = self.encoder.forward(full_emb)[:, 0, :].squeeze(1)
 
         return self.head.forward(cls_emb)
-
+
     def predict(self, img_emb, metadata):
-
+
         logits = self.forward(img_emb, metadata)
 
         # Any preprocess happens here
 
         return logits.argmax(1).tolist()
-
+
+
 class FungiEnsembleModel(nn.Module):
 
     def __init__(self, models, softmax=True) -> None:
@@ -331,39 +352,46 @@ class FungiEnsembleModel(nn.Module):
         model = model.to(self.device)
         model.eval()
         self.models.append(model)
-
+
     def forward(self, img_emb, metadata):
 
         img_emb = img_emb.to(self.device)
 
-        probs = []
+        probs = []
 
         for model in self.models:
             logits = model.forward(img_emb, metadata)
-
-            p = logits.softmax(dim=1).detach().cpu() if self.softmax else logits.detach().cpu()
+
+            p = (
+                logits.softmax(dim=1).detach().cpu()
+                if self.softmax
+                else logits.detach().cpu()
+            )
 
             probs.append(p)
 
         return torch.stack(probs).mean(dim=0)
-
+
     def predict(self, img_emb, metadata):
-
+
         logits = self.forward(img_emb, metadata)
 
         # Any preprocess happens here
 
         return logits.argmax(1).tolist()
-
+
 
 def is_gpu_available():
     """Check if the python package `onnxruntime-gpu` is installed."""
     return torch.cuda.is_available()
 
+
 class PytorchWorker:
     """Run inference using ONNX runtime."""
 
-    def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1605):
+    def __init__(
+        self, model_path: str, model_name: str, number_of_categories: int = 1605
+    ):
 
        def _load_model(model_name, model_path):
 
@@ -379,10 +407,13 @@ class PytorchWorker:
 
         self.model = _load_model(model_name, model_path)
 
-        self.transforms = T.Compose([T.Resize((518, 518)),
-                                     T.ToTensor(),
-                                     T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
-
+        self.transforms = T.Compose(
+            [
+                T.Resize((518, 518)),
+                T.ToTensor(),
+                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+            ]
+        )
 
     def predict_image(self, image: np.ndarray):
         """Run inference using ONNX runtime.
@@ -397,9 +428,9 @@
 
 
 def make_submission(metadata_df, model_names=None):
-
-    OUTPUT_CSV_PATH="./submission.csv"
-
+
+    OUTPUT_CSV_PATH = "./submission.csv"
+
     """Make submission with given """
 
     BASE_CKPT_PATH = "./checkpoints"
@@ -414,12 +445,14 @@ def make_submission(metadata_df, model_names=None):
 
         ckpt = torch.load(ckpt_path)
         model = FungiMEEModel()
-        model.load_state_dict({w: ckpt['state_dict']["model." + w] for w in model.state_dict().keys()})
+        model.load_state_dict(
+            {w: ckpt["state_dict"]["model." + w] for w in model.state_dict().keys()}
+        )
         model.eval()
         model.cuda()
 
         models.append(model)
-
+
     ensemble_model = FungiEnsembleModel(models)
 
     embedding_dataset = EmbeddingMetadataDataset(metadata_df)
@@ -433,31 +466,39 @@ def make_submission(metadata_df, model_names=None):
 
     all_preds = torch.vstack(preds).numpy()
 
-    preds_df = metadata_df[['observation_id', 'image_path']]
-    preds_df['preds'] = [i for i in all_preds]
-    preds_df = preds_df[['observation_id', 'preds']].groupby('observation_id').mean().reset_index()
-    preds_df['class_id'] = preds_df['preds'].apply(lambda x: x.argmax() if x.argmax() <= 1603 else -1)
-    preds_df[['observation_id', 'class_id']].to_csv(OUTPUT_CSV_PATH, index=None)
+    preds_df = metadata_df[["observation_id", "image_path"]]
+    preds_df["preds"] = [i for i in all_preds]
+    preds_df = (
+        preds_df[["observation_id", "preds"]]
+        .groupby("observation_id")
+        .mean()
+        .reset_index()
+    )
+    preds_df["class_id"] = preds_df["preds"].apply(
+        lambda x: x.argmax() if x.argmax() <= 1603 else -1
+    )
+    preds_df[["observation_id", "class_id"]].to_csv(OUTPUT_CSV_PATH, index=None)
 
     print("Submission complete")
 
+
 if __name__ == "__main__":
 
     MODEL_PATH = "metaformer-s-224.pth"
     MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
 
     # # Real submission
-    import zipfile
+    # import zipfile
 
-    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
-        zip_ref.extractall("/tmp/data")
+    # with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
+    #     zip_ref.extractall("/tmp/data")
 
-    metadata_file_path = "./_test_preprocessed.csv"
-    root_dir = "/tmp/data"
+    # metadata_file_path = "./_test_preprocessed.csv"
+    # root_dir = "/tmp/data"
 
    # Test submission
+    metadata_file_path = "../trial_submission.csv"
+    root_dir = "../data/DF_FULL"
 
    ##############
 
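
For reference, the two stages touched by this commit are meant to chain as follows. This is a minimal usage sketch, not part of the commit: the actual `__main__` wiring sits outside these hunks, so the call sequence below is an assumption, using the test-submission paths enabled at the end of the diff.

    # Illustrative sketch only: compose script.py's pipeline end to end.
    metadata_df = generate_embeddings(
        metadata_file_path="../trial_submission.csv",  # test inputs, per the diff
        root_dir="../data/DF_FULL",
    )  # adds a DINOv2 "embedding" column to the metadata frame
    make_submission(metadata_df)  # ensembles the FungiMEEModel checkpoints, writes ./submission.csv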