chychiu committed
Commit 8f2252f · 1 Parent(s): 3735d66
Files changed (1)
  1. script.py +399 -43
script.py CHANGED
@@ -1,12 +1,360 @@
- import pandas as pd
- import numpy as np
  import os
- from tqdm import tqdm
  import timm
  import torchvision.transforms as T
  from PIL import Image
- import torch
- from typing import List

  def is_gpu_available():
      """Check if the python package `onnxruntime-gpu` is installed."""
@@ -48,63 +396,71 @@ class PytorchWorker:
        return [-1]


- def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
      """Make submission with the given metadata."""

-     model = PytorchWorker(model_path, model_name)
-
-     predictions = []

-     for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
-         image_path = os.path.join(images_root_path, row.image_path)  # .replace("jpg", "JPG")

-         test_image = Image.open(image_path).convert("RGB")

-         logits = model.predict_image(test_image)

-         predictions.append(np.argmax(logits))

-     test_metadata["class_id"] = predictions

-     user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")

-     for ix, row in user_pred_df.iterrows():
-         if row['class_id'] == 1604:
-             user_pred_df.loc[ix, 'class_id'] = -1

-     user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)


- if __name__ == "__main__":

-     MODEL_PATH = "metaformer-s-224.pth"
-     MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"

-     # Real submission
-     # import zipfile

-     # with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
-     #     zip_ref.extractall("/tmp/data")

-     # metadata_file_path = "./test_preprocessed.csv"
-     # test_metadata = pd.read_csv(metadata_file_path)

-     # make_submission(
-     #     test_metadata=test_metadata,
-     #     model_path=MODEL_PATH,
-     #     model_name=MODEL_NAME
-     # )

-     # Test submission

-     metadata_file_path = "../trial_submission.csv"

-     test_metadata = pd.read_csv(metadata_file_path)

-     make_submission(
-         test_metadata=test_metadata,
-         model_path=MODEL_PATH,
-         model_name=MODEL_NAME,
-         images_root_path="../data/DF_FULL"
-     )

+ import io
  import os
+ from typing import List
+
+ import cv2
+ import numpy as np
+ import pandas as pd
  import timm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
  import torchvision.transforms as T
+ from albumentations import (CenterCrop, Compose, HorizontalFlip, Normalize,
+                             PadIfNeeded, RandomBrightnessContrast, RandomCrop,
+                             RandomResizedCrop, Resize, VerticalFlip)
+ from albumentations.pytorch import ToTensorV2
  from PIL import Image
+ from timm.layers import LayerNorm2d, SelectAdaptivePool2d
+ from timm.models.metaformer import MlpHead
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+
+ DEFAULT_WIDTH = 518
+ DEFAULT_HEIGHT = 518
+
+ def get_transforms(*, data, model=None, width=None, height=None):
+     assert data in ("train", "valid")
+
+     width = width if width else DEFAULT_WIDTH
+     height = height if height else DEFAULT_HEIGHT
+
+     # Fall back to generic (0.5, 0.5, 0.5) stats when no model config is given.
+     model_mean = list(model.default_cfg["mean"]) if model else (0.5, 0.5, 0.5)
+     model_std = list(model.default_cfg["std"]) if model else (0.5, 0.5, 0.5)
+
+     if data == "train":
+         return Compose(
+             [
+                 # NB: albumentations takes (height, width) positionally; both
+                 # are 518 here, so the swapped order is harmless.
+                 RandomResizedCrop(width, height, scale=(0.6, 1.0)),
+                 HorizontalFlip(p=0.5),
+                 VerticalFlip(p=0.5),
+                 RandomBrightnessContrast(p=0.2),
+                 Normalize(mean=model_mean, std=model_std),
+                 ToTensorV2(),
+             ]
+         )
+
+     elif data == "valid":
+         return Compose(
+             [
+                 Resize(width, height),
+                 Normalize(mean=model_mean, std=model_std),
+                 ToTensorV2(),
+             ]
+         )
+
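+ # Usage sketch (illustrative, not part of the commit): for an RGB uint8
+ # array `img`, the "valid" pipeline yields a normalized CHW tensor:
+ #     tfm = get_transforms(data="valid", width=518, height=518)
+ #     x = tfm(image=img)["image"]  # torch tensor of shape (3, 518, 518)
+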
+ DIM = 518
+ BASE_PATH = "../data/DF_FULL"
+
+ def generate_embeddings(metadata_file_path, root_dir):
+     metadata_df = pd.read_csv(metadata_file_path)
+
+     transforms = get_transforms(data="valid", width=DIM, height=DIM)
+
+     test_dataset = ImageMetadataDataset(
+         metadata_df, local_filepath=root_dir, transform=transforms
+     )
+
+     loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=4)
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     # The timm DINOv2 checkpoints ship without a classifier head, so forward()
+     # returns pooled features rather than class logits.
+     model = timm.create_model("timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True)
+     model = model.to(device)
+     model.eval()
+
+     all_embs = []
+     for data in tqdm(loader):
+         img, _ = data
+         img = img.to(device)
+
+         emb = model.forward(img)
+
+         all_embs.append(emb.detach().cpu().numpy())
+
+     all_embs = np.vstack(all_embs)
+
+     embs_list = [x for x in all_embs]
+     metadata_df["embedding"] = embs_list
+
+     return metadata_df
+
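+ # Shape sketch (illustrative): every row of the returned dataframe carries a
+ # pooled 1024-dim DINOv2 feature, e.g.
+ #     df = generate_embeddings("../trial_submission.csv", BASE_PATH)
+ #     df["embedding"].iloc[0].shape  # -> (1024,)
+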
+ TIME = ['m0', 'm1', 'd0', 'd1']
+ GEO = ['g0', 'g1', 'g2', 'g3', 'g4', 'g5', 'g_float']
+ # One-hot substrate, metasubstrate and habitat columns, in this order:
+ # 31 + 10 + 32 = 73 names, matching SUBSTRATE_SIZE below.
+ SUBSTRATE = (
+     [f"substrate_{i}" for i in range(31)]
+     + [f"metasubstrate_{i}" for i in range(10)]
+     + [f"habitat_{i}" for i in range(32)]
+ )
+
+ class EmbeddingMetadataDataset(Dataset):
+     def __init__(self, df):
+         self.df = df
+
+         # Positional indexing below assumes df carries a default RangeIndex.
+         self.emb = df['embedding']
+         self.metadata_date = df[TIME].to_numpy()
+         self.metadata_geo = df[GEO].to_numpy()
+         self.metadata_substrate = df[SUBSTRATE].to_numpy()
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         embedding = torch.Tensor(self.emb[idx].copy()).type(torch.float)
+
+         metadata = {
+             "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
+             "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
+             "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(torch.float),
+         }
+
+         return embedding, metadata
+
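+ # Item sketch (illustrative): given a df from generate_embeddings,
+ #     ds = EmbeddingMetadataDataset(df)
+ #     emb, meta = ds[0]
+ # yields emb of shape (1024,) and metadata vectors of shape (4,), (7,) and
+ # (73,) for "date", "geo" and "substr" respectively.
+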
+ class ImageMetadataDataset(Dataset):
+     def __init__(self, df, transform=None, local_filepath=None):
+         self.df = df
+         self.transform = transform
+         self.local_filepath = local_filepath
+
+         self.filepaths = df["image_path"].apply(lambda x: x.replace("jpg", "JPG")).to_list()
+         self.metadata_date = df[TIME].to_numpy()
+         self.metadata_geo = df[GEO].to_numpy()
+         self.metadata_substrate = df[SUBSTRATE].to_numpy()
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         file_path = os.path.join(self.local_filepath, self.filepaths[idx])
+
+         # cv2.imread returns None instead of raising on a bad path, so fail
+         # loudly here rather than crashing later inside the transform.
+         image = cv2.imread(file_path)
+         if image is None:
+             raise FileNotFoundError(f"Could not read image: {file_path}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         if self.transform:
+             augmented = self.transform(image=image)
+             image = augmented["image"]
+
+         metadata = {
+             "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
+             "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
+             "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(torch.float),
+         }
+
+         return image, metadata
+
+ DATE_SIZE = 4
+ GEO_SIZE = 7
+ SUBSTRATE_SIZE = 73
+ NUM_CLASSES = 1717
+
+ class StarReLU(nn.Module):
+     """
+     StarReLU: s * relu(x) ** 2 + b
+     """
+
+     def __init__(
+         self,
+         scale_value=1.0,
+         bias_value=0.0,
+         scale_learnable=True,
+         bias_learnable=True,
+         mode=None,
+         inplace=False,
+     ):
+         super().__init__()
+         self.inplace = inplace
+         self.relu = nn.ReLU(inplace=inplace)
+         self.scale = nn.Parameter(
+             scale_value * torch.ones(1), requires_grad=scale_learnable
+         )
+         self.bias = nn.Parameter(
+             bias_value * torch.ones(1), requires_grad=bias_learnable
+         )
+
+     def forward(self, x):
+         return self.scale * self.relu(x) ** 2 + self.bias
+
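+ # With the defaults above (scale initialized to 1, bias to 0), StarReLU
+ # starts out as relu(x) ** 2, the squared ReLU used in the MetaFormer
+ # baselines; both parameters are then learned.
+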
+ class FungiMEEModel(nn.Module):
+     def __init__(
+         self,
+         num_classes=NUM_CLASSES,
+         dim=1024,
+     ):
+         super().__init__()
+
+         print("Setting up Pytorch Model")
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+
+         # Project each metadata group into the same `dim`-dimensional space
+         # as the image embedding.
+         self.date_embedding = MlpHead(
+             dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
+         )
+         self.geo_embedding = MlpHead(
+             dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
+         )
+         self.substr_embedding = MlpHead(
+             dim=SUBSTRATE_SIZE,
+             num_classes=dim,
+             mlp_ratio=8,
+             act_layer=StarReLU,
+         )
+
+         self.encoder = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
+             num_layers=4,
+         )
+
+         self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)
+
+         for param in self.parameters():
+             if param.dim() > 1:
+                 nn.init.kaiming_normal_(param)
+
+     def forward(self, img_emb, metadata):
+         img_emb = img_emb.to(self.device)
+
+         date_emb = self.date_embedding.forward(metadata["date"].to(self.device))
+         geo_emb = self.geo_embedding.forward(metadata["geo"].to(self.device))
+         substr_emb = self.substr_embedding.forward(metadata["substr"].to(self.device))
+
+         # Treat the four embeddings as a length-4 token sequence: (B, 4, dim).
+         full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)
+
+         cls_emb = self.encoder.forward(full_emb)[:, 0, :].squeeze(1)
+
+         return self.head.forward(cls_emb)
+
+     def predict(self, img_emb, metadata):
+         logits = self.forward(img_emb, metadata)
+
+         # Any post-processing of the logits happens here
+
+         return logits.argmax(1).tolist()
+
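+ # Architecture note: a late-fusion head over a frozen image embedding. The
+ # DINOv2 feature and the three metadata projections form a 4-token sequence
+ # that the small transformer encoder mixes, and classification reads the
+ # encoder output at the image token's position.
+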
+ class FungiEnsembleModel(nn.Module):
+
+     def __init__(self, models, softmax=True) -> None:
+         super().__init__()
+
+         self.models = nn.ModuleList()
+         self.softmax = softmax
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+         for model in models:
+             model = model.to(self.device)
+             model.eval()
+             self.models.append(model)
+
+     def forward(self, img_emb, metadata):
+         img_emb = img_emb.to(self.device)
+
+         probs = []
+
+         for model in self.models:
+             logits = model.forward(img_emb, metadata)
+
+             p = logits.softmax(dim=1).detach().cpu() if self.softmax else logits.detach().cpu()
+
+             probs.append(p)
+
+         # Average the member predictions (probabilities by default).
+         return torch.stack(probs).mean(dim=0)
+
+     def predict(self, img_emb, metadata):
+         logits = self.forward(img_emb, metadata)
+
+         # Any post-processing of the logits happens here
+
+         return logits.argmax(1).tolist()
+
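+ # Usage sketch (illustrative; model_a and model_b are placeholder names):
+ #     ensemble = FungiEnsembleModel([model_a, model_b])
+ #     probs = ensemble(img_emb, metadata)  # (B, NUM_CLASSES); rows sum to 1
+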

  def is_gpu_available():
      """Check if the python package `onnxruntime-gpu` is installed."""
  ...
        return [-1]

+ def make_submission(metadata_df, model_names=None):
+
+     OUTPUT_CSV_PATH = "./submission.csv"
+
      """Make submission with the given metadata."""

+     BASE_CKPT_PATH = "./checkpoints"
+
+     model_names = model_names or os.listdir(BASE_CKPT_PATH)
+
+     models = []
+
+     for model_path in model_names:
+         print("loading ", model_path)
+         ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
+
+         ckpt = torch.load(ckpt_path)
+         model = FungiMEEModel()
+         # The checkpoints store the weights under a "model." prefix, so strip
+         # it when loading into the bare module.
+         model.load_state_dict({w: ckpt['state_dict']["model." + w] for w in model.state_dict().keys()})
+         model.eval()
+         model.cuda()
+
+         models.append(model)
+
+     ensemble_model = FungiEnsembleModel(models)
+
+     embedding_dataset = EmbeddingMetadataDataset(metadata_df)
+     loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
+
+     preds = []
+     for data in tqdm(loader):
+         emb, metadata = data
+         pred = ensemble_model.forward(emb, metadata)
+         preds.append(pred)
+
+     all_preds = torch.vstack(preds).numpy()
+
+     # Average image-level probabilities per observation, then map any argmax
+     # above 1603 to the -1 label.
+     preds_df = metadata_df[['observation_id', 'image_path']].copy()
+     preds_df['preds'] = [i for i in all_preds]
+     preds_df = preds_df[['observation_id', 'preds']].groupby('observation_id').mean().reset_index()
+     preds_df['class_id'] = preds_df['preds'].apply(lambda x: x.argmax() if x.argmax() <= 1603 else -1)
+     preds_df[['observation_id', 'class_id']].to_csv(OUTPUT_CSV_PATH, index=None)
+
+     print("Submission complete")

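+ # The resulting submission.csv has two columns, observation_id and class_id,
+ # with one row per unique observation.
+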
+ if __name__ == "__main__":
+
+     MODEL_PATH = "metaformer-s-224.pth"
+     MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"
+
+     # Real submission
+     import zipfile
+
+     with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
+         zip_ref.extractall("/tmp/data")
+
+     metadata_file_path = "./_test_preprocessed.csv"
+     root_dir = "/tmp/data"
+
+     # Test submission
+     # metadata_file_path = "../trial_submission.csv"
+     # root_dir = "../data/DF_FULL"
+
+     ##############
+
+     metadata_df = generate_embeddings(metadata_file_path, root_dir)
+
+     make_submission(metadata_df)