|
import os |
|
import numpy as np |
|
import pandas as pd |
|
import timm |
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image |
|
from timm.models.metaformer import MlpHead |
|
from torch.utils.data import DataLoader, Dataset |
|
from tqdm import tqdm |
|
from albumentations import Compose, Normalize, Resize |
|
from albumentations.pytorch import ToTensorV2 |
|
import cv2 |
|
|
|
# Side length passed to Resize (images become DIM x DIM).
DIM = 518

# Widths of the three metadata groups; they match the lengths of TIME, GEO
# and SUBSTRATE below and are used as MlpHead input dims.
DATE_SIZE = 4
GEO_SIZE = 7
SUBSTRATE_SIZE = 73

# Default size of the classification head output.
NUM_CLASSES = 1717

# Column names selected from the metadata frame for each metadata group.
TIME = ["m0", "m1", "d0", "d1"]
GEO = [f"g{i}" for i in range(6)] + ["g_float"]

# 31 "substrate_*", 10 "metasubstrate_*" and 32 "habitat_*" columns, in that
# order, each numbered from 0.
SUBSTRATE = [
    f"{prefix}_{i}"
    for prefix, count in (("substrate", 31), ("metasubstrate", 10), ("habitat", 32))
    for i in range(count)
]
|
|
|
|
|
class ImageDataset(Dataset):
    """Dataset that loads and preprocesses the images listed in a metadata frame.

    Args:
        df: DataFrame with an ``image_path`` column; each entry is joined onto
            ``local_filepath`` to locate the image on disk.
        local_filepath: Directory the ``image_path`` entries are relative to.
    """

    def __init__(self, df, local_filepath):
        self.df = df
        self.local_filepath = local_filepath
        self.filepaths = df["image_path"].to_list()

        # Resize to DIM x DIM, normalise each channel with mean 0.5 / std 0.5,
        # then convert to a torch tensor.
        self.transform = Compose(
            [
                Resize(DIM, DIM),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2(),
            ]
        )

    def __len__(self):
        """Return the number of rows in the metadata frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load image ``idx`` from disk and return it transformed.

        Raises:
            FileNotFoundError: if the resolved path is not an existing file.
            ValueError: if the file exists but OpenCV cannot decode it.
        """
        image_path = os.path.join(self.local_filepath, self.filepaths[idx])

        # Fail early with the offending path instead of an opaque OpenCV
        # assertion inside cvtColor.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not decode image: {image_path}")

        # OpenCV loads colour images as BGR; convert to RGB before transforming.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return self.transform(image=image)["image"]
|
|
|
|
|
class EmbeddingMetadataDataset(Dataset):
    """Dataset pairing each row's image embedding with its metadata tensors.

    Args:
        df: DataFrame with an ``embedding`` column (one array-like per row)
            plus the columns named in ``TIME``, ``GEO`` and ``SUBSTRATE``.
    """

    def __init__(self, df):
        self.df = df

        self.emb = df["embedding"]
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        """Return the number of rows in the frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(embedding, metadata)`` for the row at position ``idx``.

        ``metadata`` is a dict with float tensors under the keys ``"date"``,
        ``"geo"`` and ``"substr"``.
        """
        # Use positional access: plain ``self.emb[idx]`` is label-based and
        # breaks (or returns a different row than the metadata arrays) when
        # the frame does not have a default 0..n-1 index.
        # torch.tensor copies, so the stored embedding is never aliased.
        embedding = torch.tensor(np.asarray(self.emb.iloc[idx]), dtype=torch.float)

        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }

        return embedding, metadata
|
|
|
|
|
def generate_embeddings(metadata_file_path, root_dir):
    """Compute an image embedding for every row of a metadata CSV.

    Args:
        metadata_file_path: Path to a CSV readable by ``pd.read_csv``; it must
            contain the ``image_path`` column that ``ImageDataset`` reads.
        root_dir: Directory the ``image_path`` entries are relative to.

    Returns:
        The loaded DataFrame with an added ``embedding`` column holding one
        NumPy array per row (the model output for that image).
    """
    DINOV2_CKPT = "./checkpoints/dinov2.bin"

    metadata_df = pd.read_csv(metadata_file_path)
    test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
    # shuffle=False keeps embeddings aligned with the DataFrame row order.
    loader = DataLoader(test_dataset, batch_size=3, shuffle=False)

    device = torch.device("cpu")
    model = timm.create_model(
        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
    )
    # map_location lets a checkpoint saved from another device load here.
    weights = torch.load(DINOV2_CKPT, map_location=device)
    model.load_state_dict(weights)

    model = model.to(device)
    model.eval()

    all_embs = []
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for img in tqdm(loader):
            emb = model(img.to(device))
            all_embs.append(emb.cpu().numpy())

    # One array per row; np.vstack rejects an empty list, so handle the
    # no-images case explicitly.
    embs_list = list(np.vstack(all_embs)) if all_embs else []
    metadata_df["embedding"] = embs_list

    return metadata_df
|
|
|
|
|
class StarReLU(nn.Module):
    """Activation computing ``scale * relu(x) ** 2 + bias``.

    ``scale`` and ``bias`` are single-element parameters initialised to
    ``scale_value`` and ``bias_value``; each can be frozen through its
    ``*_learnable`` flag. ``mode`` is accepted but not used.
    """

    def __init__(
        self,
        scale_value=1.0,
        bias_value=0.0,
        scale_learnable=True,
        bias_learnable=True,
        mode=None,
        inplace=False,
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)

        initial_scale = scale_value * torch.ones(1)
        initial_bias = bias_value * torch.ones(1)
        self.scale = nn.Parameter(initial_scale, requires_grad=scale_learnable)
        self.bias = nn.Parameter(initial_bias, requires_grad=bias_learnable)

    def forward(self, x):
        """Apply the squared-ReLU with learnable affine terms to ``x``."""
        rectified = self.relu(x)
        squared = rectified ** 2
        return self.scale * squared + self.bias
|
|
|
|
|
class FungiMEEModel(nn.Module):
    """Classifier fusing an image embedding with date/geo/substrate metadata.

    Each metadata group is projected to ``dim`` features by its own MlpHead.
    The image embedding and the three projections are stacked as a 4-token
    sequence, run through a transformer encoder, and the output at token 0
    (the image position) is classified by the head.

    Args:
        num_classes: Number of output logits.
        dim: Feature width shared by all tokens and the encoder.
            Presumably the width of the incoming image embeddings; confirm
            against the embedding model.
    """

    def __init__(
        self,
        num_classes=NUM_CLASSES,
        dim=1024,
    ):
        super().__init__()

        print("Setting up Pytorch Model")
        # Kept as an attribute for backward compatibility; forward() follows
        # the device the parameters actually live on.
        self.device = torch.device("cpu")
        print(f"Using devide: {self.device}")

        self.date_embedding = MlpHead(
            dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.geo_embedding = MlpHead(
            dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.substr_embedding = MlpHead(
            dim=SUBSTRATE_SIZE,
            num_classes=dim,
            mlp_ratio=8,
            act_layer=StarReLU,
        )

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=4,
        )

        self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)

        # Re-initialise every weight matrix (parameters with more than one
        # dimension); biases and other 1-D parameters keep their defaults.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def forward(self, img_emb, metadata):
        """Return class logits for a batch.

        Args:
            img_emb: Image embeddings; assumed to be ``(batch, dim)`` so they
                stack with the metadata projections (TODO confirm).
            metadata: Dict with tensors under ``"date"``, ``"geo"`` and
                ``"substr"``.
        """
        # Follow the module's real device so the model keeps working after
        # ``.to(...)``, instead of forcing inputs onto a hard-coded device.
        device = next(self.parameters()).device

        tokens = (
            img_emb.to(device),
            self.date_embedding(metadata["date"].to(device)),
            self.geo_embedding(metadata["geo"].to(device)),
            self.substr_embedding(metadata["substr"].to(device)),
        )
        full_emb = torch.stack(tokens, dim=1)

        # Token 0 is the image position; its encoded output is the summary
        # vector fed to the classification head.
        cls_emb = self.encoder(full_emb)[:, 0, :]

        return self.head(cls_emb)

    def predict(self, img_emb, metadata):
        """Return the arg-max class index for each sample as a Python list."""
        # Prediction needs no gradients.
        with torch.no_grad():
            logits = self.forward(img_emb, metadata)

        return logits.argmax(1).tolist()
|
|
|
|
|
class FungiEnsembleModel(nn.Module):
    """Inference-only ensemble that averages member softmax probabilities.

    Args:
        models: Iterable of modules whose forward takes ``(img_emb, metadata)``
            and returns logits with classes on dimension 1. Each is moved to
            CPU and switched to eval mode.
    """

    def __init__(self, models) -> None:
        super().__init__()

        self.models = nn.ModuleList()
        self.device = torch.device("cpu")

        for model in models:
            model = model.to(self.device)
            model.eval()
            self.models.append(model)

    def forward(self, img_emb, metadata):
        """Return the mean class probabilities over all members (CPU tensor)."""
        img_emb = img_emb.to(self.device)

        probs = []
        # The result was always detached, so run the members without
        # building autograd graphs in the first place.
        with torch.no_grad():
            for model in self.models:
                logits = model(img_emb, metadata)
                probs.append(logits.softmax(dim=1).cpu())

        return torch.stack(probs).mean(dim=0)

    def predict(self, img_emb, metadata):
        """Return the arg-max class index for each sample as a Python list."""
        probs = self.forward(img_emb, metadata)

        return probs.argmax(1).tolist()
|
|
|
|
|
def make_submission(metadata_df):
    """Run the checkpoint ensemble and write ``./submission.csv``.

    Args:
        metadata_df: DataFrame with ``observation_id`` plus every column
            ``EmbeddingMetadataDataset`` reads (``embedding`` and the metadata
            columns). Rows are consumed in frame order.

    Side effects:
        Loads the listed checkpoints from ``./checkpoints`` and writes a CSV
        with columns ``observation_id`` and ``class_id``, one row per
        observation.
    """
    OUTPUT_CSV_PATH = "./submission.csv"
    BASE_CKPT_PATH = "./checkpoints"
    # Predicted indices above this value are reported as -1.
    # Presumably the last valid known-class id; confirm against the task.
    MAX_REPORTED_CLASS_ID = 1603

    model_names = [
        "dino_2_optuna_05242231.ckpt",
        "dino_optuna_05241449.ckpt",
        "dino_optuna_05241257.ckpt",
        "dino_optuna_05241222.ckpt",
        "dino_2_optuna_05242055.ckpt",
        "dino_2_optuna_05242156.ckpt",
        "dino_2_optuna_05242344.ckpt",
    ]

    models = []
    for model_path in model_names:
        print("loading ", model_path)
        ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)

        # The ensemble runs on CPU; map_location lets checkpoints saved from
        # another device load here.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model = FungiMEEModel()
        # Checkpoint keys carry a "model." prefix in front of the module's
        # own state-dict names.
        model.load_state_dict({w: ckpt["model." + w] for w in model.state_dict()})
        model.eval()

        models.append(model)

    fungi_model = FungiEnsembleModel(models)

    embedding_dataset = EmbeddingMetadataDataset(metadata_df)
    # shuffle=False keeps predictions aligned with metadata_df row order.
    loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)

    preds = []
    with torch.no_grad():
        for emb, metadata in tqdm(loader):
            preds.append(fungi_model(emb, metadata))

    all_preds = torch.vstack(preds).numpy()

    # Average the per-image probabilities of each observation on a plain
    # numeric frame (one column per class). This avoids assigning a column
    # onto a DataFrame slice and avoids reducing an object column of arrays.
    mean_probs = (
        pd.DataFrame(all_preds)
        .groupby(metadata_df["observation_id"].to_numpy())
        .mean()
    )
    best_class = mean_probs.to_numpy().argmax(axis=1)

    preds_df = pd.DataFrame(
        {
            "observation_id": mean_probs.index.to_numpy(),
            "class_id": np.where(best_class <= MAX_REPORTED_CLASS_ID, best_class, -1),
        }
    )
    preds_df.to_csv(OUTPUT_CSV_PATH, index=False)

    print("Submission complete")
|
|
|
|
|
if __name__ == "__main__":
    import zipfile

    # Where the extracted images live and which CSV describes them.
    metadata_file_path = "./_test_preprocessed.csv"
    root_dir = "/tmp/data/private_testset"

    # Unpack the test archive into /tmp/data/ before anything reads images.
    archive = zipfile.ZipFile("/tmp/data/private_testset.zip", "r")
    with archive as zip_ref:
        zip_ref.extractall("/tmp/data/")

    # Embed every image, then run the ensemble and write the submission.
    metadata_df = generate_embeddings(metadata_file_path, root_dir)
    make_submission(metadata_df)
|
|
|
|
|
|
|
|
|
|