|
import os

import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
from albumentations import (
    Compose,
    HorizontalFlip,
    Normalize,
    RandomBrightnessContrast,
    RandomResizedCrop,
    Resize,
    VerticalFlip,
)
from albumentations.pytorch import ToTensorV2
from PIL import Image
from timm.models.metaformer import MlpHead
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Native input resolution of the DINOv2 ViT-L/14 (reg4) checkpoints.
DIM = 518
|
|
|
def get_transforms(*, data, model=None, width=None, height=None):
    """Build the albumentations pipeline for the given split ("train" or "valid")."""
    assert data in ("train", "valid")

    width = width if width else DIM
    height = height if height else DIM

    # Fall back to the (0.5, 0.5, 0.5) normalization used by the DINOv2
    # checkpoints when no model config is supplied.
    model_mean = list(model.default_cfg["mean"]) if model else (0.5, 0.5, 0.5)
    model_std = list(model.default_cfg["std"]) if model else (0.5, 0.5, 0.5)

    if data == "train":
        return Compose(
            [
                # albumentations size arguments are ordered (height, width).
                RandomResizedCrop(height, width, scale=(0.6, 1.0)),
                HorizontalFlip(p=0.5),
                VerticalFlip(p=0.5),
                RandomBrightnessContrast(p=0.2),
                Normalize(mean=model_mean, std=model_std),
                ToTensorV2(),
            ]
        )
    elif data == "valid":
        return Compose(
            [
                Resize(height, width),
                Normalize(mean=model_mean, std=model_std),
                ToTensorV2(),
            ]
        )
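
# Illustrative check of the validation pipeline (dummy input, shapes only):
#
#     tfms = get_transforms(data="valid")
#     out = tfms(image=np.zeros((600, 400, 3), dtype=np.uint8))
#     out["image"].shape  # -> torch.Size([3, 518, 518])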
|
|
|
|
|
def generate_embeddings(metadata_file_path, root_dir):
    """Embed every image listed in the metadata CSV with a frozen DINOv2 backbone."""
    metadata_df = pd.read_csv(metadata_file_path)

    transforms = get_transforms(data="valid", width=DIM, height=DIM)

    test_dataset = ImageMetadataDataset(
        metadata_df, local_filepath=root_dir, transform=transforms
    )

    loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The DINOv2 checkpoints ship without a classifier head, so the forward
    # pass returns the pooled 1024-dim ViT-L/14 feature directly.
    model = timm.create_model(
        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
    )
    model = model.to(device)
    model.eval()

    all_embs = []
    with torch.no_grad():
        for data in tqdm(loader):
            img, _ = data
            img = img.to(device)

            emb = model(img)
            all_embs.append(emb.cpu().numpy())

    all_embs = np.vstack(all_embs)

    metadata_df["embedding"] = list(all_embs)

    return metadata_df
|
|
|
|
|
TIME = ["m0", "m1", "d0", "d1"]
GEO = ["g0", "g1", "g2", "g3", "g4", "g5", "g_float"]
# 31 substrate + 10 metasubstrate + 32 habitat one-hot columns (73 in total).
SUBSTRATE = (
    [f"substrate_{i}" for i in range(31)]
    + [f"metasubstrate_{i}" for i in range(10)]
    + [f"habitat_{i}" for i in range(32)]
)
|
|
|
|
|
class EmbeddingMetadataDataset(Dataset):
    def __init__(self, df):
        self.df = df

        # Use a plain list so positional indexing works regardless of the
        # DataFrame's index.
        self.emb = df["embedding"].to_list()
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        embedding = torch.Tensor(self.emb[idx].copy()).type(torch.float)

        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }

        return embedding, metadata
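
# Each item is a pair (embedding, metadata): the embedding is a (1024,) float
# tensor, and metadata maps "date" -> (4,), "geo" -> (7,) and "substr" -> (73,)
# float tensors (see DATE_SIZE / GEO_SIZE / SUBSTRATE_SIZE below).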
|
|
|
|
|
class ImageMetadataDataset(Dataset):
    def __init__(self, df, transform=None, local_filepath=None):
        self.df = df
        self.transform = transform
        self.local_filepath = local_filepath

        self.filepaths = df["image_path"].to_list()
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = os.path.join(self.local_filepath, self.filepaths[idx])

        # cv2.imread returns None instead of raising on unreadable files, so
        # fail loudly here rather than crashing later with an opaque error.
        image = cv2.imread(file_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]

        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }

        return image, metadata
|
|
|
|
|
DATE_SIZE = 4  # len(TIME)
GEO_SIZE = 7  # len(GEO)
SUBSTRATE_SIZE = 73  # len(SUBSTRATE)
NUM_CLASSES = 1717  # number of target classes predicted by the classifier head
|
|
|
|
|
class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """

    def __init__(
        self,
        scale_value=1.0,
        bias_value=0.0,
        scale_learnable=True,
        bias_learnable=True,
        mode=None,
        inplace=False,
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(
            scale_value * torch.ones(1), requires_grad=scale_learnable
        )
        self.bias = nn.Parameter(
            bias_value * torch.ones(1), requires_grad=bias_learnable
        )

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias
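
# With the default scale/bias (s=1, b=0), StarReLU squares the positive part:
#
#     StarReLU()(torch.tensor([-1.0, 2.0]))  # -> tensor([0., 4.])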
|
|
|
|
|
class FungiMEEModel(nn.Module):
    def __init__(
        self,
        num_classes=NUM_CLASSES,
        dim=1024,
    ):
        super().__init__()

        print("Setting up Pytorch Model")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Project each metadata group into the same embedding space as the
        # image feature so all four can be fed to the transformer as tokens.
        self.date_embedding = MlpHead(
            dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.geo_embedding = MlpHead(
            dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.substr_embedding = MlpHead(
            dim=SUBSTRATE_SIZE,
            num_classes=dim,
            mlp_ratio=8,
            act_layer=StarReLU,
        )

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=4,
        )

        self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)

        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def forward(self, img_emb, metadata):
        img_emb = img_emb.to(self.device)

        date_emb = self.date_embedding(metadata["date"].to(self.device))
        geo_emb = self.geo_embedding(metadata["geo"].to(self.device))
        substr_emb = self.substr_embedding(metadata["substr"].to(self.device))

        # Four tokens per sample: image, date, geo, substrate -> (B, 4, dim).
        full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)

        # Use the image token's output as a CLS-style summary of the sequence.
        cls_emb = self.encoder(full_emb)[:, 0, :]

        return self.head(cls_emb)

    def predict(self, img_emb, metadata):
        logits = self.forward(img_emb, metadata)

        return logits.argmax(1).tolist()
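
# Minimal smoke test (illustrative shapes, random inputs):
#
#     model = FungiMEEModel()
#     img_emb = torch.randn(2, 1024)  # pooled DINOv2 ViT-L/14 features
#     metadata = {
#         "date": torch.randn(2, DATE_SIZE),
#         "geo": torch.randn(2, GEO_SIZE),
#         "substr": torch.randn(2, SUBSTRATE_SIZE),
#     }
#     model.predict(img_emb, metadata)  # -> list of 2 predicted class ids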
|
|
|
|
|
class FungiEnsembleModel(nn.Module):
    def __init__(self, models, softmax=True) -> None:
        super().__init__()

        self.models = nn.ModuleList()
        self.softmax = softmax
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        for model in models:
            model = model.to(self.device)
            model.eval()
            self.models.append(model)

    def forward(self, img_emb, metadata):
        img_emb = img_emb.to(self.device)

        probs = []
        for model in self.models:
            logits = model(img_emb, metadata)

            # Average member softmax probabilities (raw logits if softmax=False).
            p = (
                logits.softmax(dim=1).detach().cpu()
                if self.softmax
                else logits.detach().cpu()
            )
            probs.append(p)

        return torch.stack(probs).mean(dim=0)

    def predict(self, img_emb, metadata):
        logits = self.forward(img_emb, metadata)

        return logits.argmax(1).tolist()
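
# Illustrative usage: averaging two (untrained) members gives a
# (B, NUM_CLASSES) matrix whose rows sum to 1 when softmax=True.
#
#     ensemble = FungiEnsembleModel([FungiMEEModel(), FungiMEEModel()])
#     probs = ensemble(img_emb, metadata)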
|
|
|
|
|
def is_gpu_available():
    """Check whether a CUDA-capable GPU is visible to PyTorch."""
    return torch.cuda.is_available()
|
|
|
|
|
class PytorchWorker:
    """Run inference using a timm PyTorch backbone."""

    def __init__(
        self, model_path: str, model_name: str, number_of_categories: int = 1605
    ):
        def _load_model(model_name, model_path):
            print("Setting up Pytorch Model")
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu"
            )
            print(f"Using device: {self.device}")

            model = timm.create_model(model_name, num_classes=0, pretrained=False)

            # Assumes model_path points to a plain state_dict checkpoint;
            # adjust if your checkpoint nests the weights differently.
            model.load_state_dict(torch.load(model_path, map_location=self.device))

            return model.to(self.device).eval()

        self.model = _load_model(model_name, model_path)

        self.transforms = T.Compose(
            [
                T.Resize((518, 518)),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

    def predict_image(self, image: np.ndarray):
        """Run a forward pass and return the model output for one image.

        :param image: Input image as a numpy array (H, W, 3), RGB.
        :return: The output vector for the single image as a flat list.
        """
        # torchvision's Resize expects a PIL image (or tensor), not a raw array.
        pil_image = Image.fromarray(image)
        logits = self.model(self.transforms(pil_image).unsqueeze(0).to(self.device))

        return logits.detach().cpu().numpy().flatten().tolist()
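
# Hypothetical standalone usage (paths are placeholders):
#
#     worker = PytorchWorker(MODEL_PATH, MODEL_NAME)
#     worker.predict_image(np.asarray(Image.open("example.jpg").convert("RGB")))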
|
|
|
|
|
def make_submission(metadata_df, model_names=None):
    """Ensemble every checkpoint in BASE_CKPT_PATH and write the submission CSV."""

    OUTPUT_CSV_PATH = "./submission.csv"
    BASE_CKPT_PATH = "./checkpoints"

    model_names = model_names or os.listdir(BASE_CKPT_PATH)

    models = []

    for model_path in model_names:
        print("loading ", model_path)
        ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)

        ckpt = torch.load(ckpt_path, map_location="cpu")
        model = FungiMEEModel()
        # The checkpoints were saved from a wrapper module, so strip the
        # "model." prefix from every weight name.
        model.load_state_dict(
            {w: ckpt["state_dict"]["model." + w] for w in model.state_dict().keys()}
        )
        model.eval()
        model.to(model.device)

        models.append(model)

    ensemble_model = FungiEnsembleModel(models)

    embedding_dataset = EmbeddingMetadataDataset(metadata_df)
    loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)

    preds = []
    with torch.no_grad():
        for data in tqdm(loader):
            emb, metadata = data
            pred = ensemble_model(emb, metadata)
            preds.append(pred)

    all_preds = torch.vstack(preds).numpy()

    # Copy to avoid pandas' SettingWithCopy warning when adding the column.
    preds_df = metadata_df[["observation_id", "image_path"]].copy()
    preds_df["preds"] = list(all_preds)
    # Average the per-image probabilities over each observation.
    preds_df = (
        preds_df[["observation_id", "preds"]]
        .groupby("observation_id")
        .mean()
        .reset_index()
    )
    # Class ids above 1603 are mapped to -1 ("unknown").
    preds_df["class_id"] = preds_df["preds"].apply(
        lambda x: x.argmax() if x.argmax() <= 1603 else -1
    )
    preds_df[["observation_id", "class_id"]].to_csv(OUTPUT_CSV_PATH, index=False)

    print("Submission complete")
|
|
|
|
|
if __name__ == "__main__":

    MODEL_PATH = "metaformer-s-224.pth"
    MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"

    metadata_file_path = "../trial_submission.csv"
    root_dir = "../data/DF_FULL"

    metadata_df = generate_embeddings(metadata_file_path, root_dir)

    make_submission(metadata_df)
|
|