import io |
import os |
from typing import List |
import cv2 |
import numpy as np |
import pandas as pd |
import timm |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
import torchvision.transforms as T |
from albumentations import (CenterCrop, Compose, HorizontalFlip, Normalize, |
PadIfNeeded, RandomBrightnessContrast, RandomCrop, |
RandomResizedCrop, Resize, VerticalFlip) |
from albumentations.pytorch import ToTensorV2 |
from PIL import Image |
from timm.layers import LayerNorm2d, SelectAdaptivePool2d |
from timm.models.metaformer import MlpHead |
from torch.utils.data import DataLoader, Dataset |
from tqdm import tqdm |
def get_transforms(*, data, model=None, width=None, height=None):
    """Build the albumentations pipeline for training or validation images.

    Args:
        data: Either ``"train"`` (random crop, flips, brightness/contrast)
            or ``"valid"`` (plain resize).
        model: Optional object exposing ``default_cfg["mean"]`` and
            ``default_cfg["std"]``; when omitted, 0.5 is used for every
            channel of both mean and std.
        width: Output width in pixels; falls back to ``DEFAULT_WIDTH``.
        height: Output height in pixels; falls back to ``DEFAULT_HEIGHT``.

    Returns:
        An albumentations ``Compose`` ending in ``Normalize`` and
        ``ToTensorV2``.

    Raises:
        ValueError: If ``data`` is neither ``"train"`` nor ``"valid"``.
    """
    # Validate with an exception rather than `assert`, which is stripped
    # when Python runs with -O.
    if data not in ("train", "valid"):
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    # NOTE(review): DEFAULT_WIDTH / DEFAULT_HEIGHT are not defined in the
    # visible part of this file - confirm they exist before relying on the
    # fallbacks.
    width = width if width else DEFAULT_WIDTH
    height = height if height else DEFAULT_HEIGHT

    # Compare against None explicitly: an nn.Module that defines __len__
    # (e.g. an empty container) would otherwise be treated as "no model".
    if model is not None:
        model_mean = list(model.default_cfg["mean"])
        model_std = list(model.default_cfg["std"])
    else:
        model_mean = (0.5, 0.5, 0.5)
        model_std = (0.5, 0.5, 0.5)

    # albumentations takes (height, width); pass both by keyword so that
    # non-square sizes are not silently transposed.
    if data == "train":
        return Compose(
            [
                RandomResizedCrop(height=height, width=width, scale=(0.6, 1.0)),
                HorizontalFlip(p=0.5),
                VerticalFlip(p=0.5),
                RandomBrightnessContrast(p=0.2),
                Normalize(mean=model_mean, std=model_std),
                ToTensorV2(),
            ]
        )
    return Compose(
        [
            Resize(height=height, width=width),
            Normalize(mean=model_mean, std=model_std),
            ToTensorV2(),
        ]
    )
# Side length in pixels; generate_embeddings passes it as both the width and
# the height of the validation transform.
DIM: int = 518

# Relative dataset root. It is not referenced anywhere else in this file.
BASE_PATH: str = "../data/DF_FULL"
def generate_embeddings(metadata_file_path, root_dir):
    """Compute one image embedding per metadata row with a pretrained ViT.

    Args:
        metadata_file_path: Path to the CSV that is read into the returned
            frame; it must provide ``image_path`` plus the TIME, GEO and
            SUBSTRATE columns consumed by ``ImageMetadataDataset``.
        root_dir: Directory that the ``image_path`` values are joined onto.

    Returns:
        The metadata DataFrame with an added ``embedding`` column holding
        one 1-D numpy array (the model's forward output) per row.
    """
    metadata_df = pd.read_csv(metadata_file_path)
    transforms = get_transforms(data="valid", width=DIM, height=DIM)
    test_dataset = ImageMetadataDataset(
        metadata_df, local_filepath=root_dir, transform=transforms
    )
    # shuffle=False keeps the embeddings aligned with the rows of metadata_df.
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = timm.create_model(
        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
    )
    model = model.to(device)
    model.eval()

    all_embs = []
    # Pure inference: no_grad avoids building an autograd graph per image.
    with torch.no_grad():
        for img, _ in tqdm(loader):
            emb = model(img.to(device))
            all_embs.append(emb.cpu().numpy())

    # np.vstack raises on an empty list, so an empty CSV yields an empty column.
    metadata_df["embedding"] = list(np.vstack(all_embs)) if all_embs else []
    return metadata_df
# Column names of the date features read from the metadata frame.
TIME = ["m0", "m1", "d0", "d1"]

# Column names of the geographic features: g0..g5 plus one extra column.
GEO = [f"g{i}" for i in range(6)] + ["g_float"]

# Column names of the substrate features, in this exact order:
# substrate_0..30, then metasubstrate_0..9, then habitat_0..31 (73 in total).
SUBSTRATE = [
    f"{prefix}_{i}"
    for prefix, count in (("substrate", 31), ("metasubstrate", 10), ("habitat", 32))
    for i in range(count)
]
class EmbeddingMetadataDataset(Dataset):
    """Dataset yielding a precomputed embedding and its metadata tensors.

    Args:
        df: DataFrame with an ``embedding`` column (one array per row) and
            the columns listed in TIME, GEO and SUBSTRATE.
    """

    def __init__(self, df):
        self.df = df
        self.emb = df["embedding"]
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(embedding, metadata)`` for the row at position ``idx``.

        ``metadata`` is a dict with float tensors under ``"date"``,
        ``"geo"`` and ``"substr"``.
        """
        # Positional lookup: the metadata arrays are positional, so the
        # embedding must be too. Label-based `self.emb[idx]` breaks (KeyError
        # or wrong row) when the frame does not carry a default 0..n-1 index.
        # torch.tensor copies, so the stored array is never aliased.
        embedding = torch.tensor(np.asarray(self.emb.iloc[idx]), dtype=torch.float)
        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(torch.float),
        }
        return embedding, metadata
class ImageMetadataDataset(Dataset):
    """Dataset yielding an RGB image and its metadata tensors.

    Args:
        df: DataFrame with an ``image_path`` column and the columns listed
            in TIME, GEO and SUBSTRATE.
        transform: Optional callable invoked as ``transform(image=...)``
            whose result is indexed with ``["image"]``.
        local_filepath: Directory that each ``image_path`` is joined onto.
    """

    def __init__(self, df, transform=None, local_filepath=None):
        self.df = df
        self.transform = transform
        self.local_filepath = local_filepath
        # NOTE(review): this replaces every "jpg" substring in the path, not
        # only the extension - confirm no directory names contain "jpg".
        self.filepaths = df["image_path"].apply(lambda x: x.replace("jpg", "JPG")).to_list()
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, metadata)`` for the row at position ``idx``.

        Raises:
            FileNotFoundError: If the image file does not exist.
            ValueError: If the file exists but cannot be decoded.
        """
        file_path = os.path.join(self.local_filepath, self.filepaths[idx])
        # cv2.imread signals failure by returning None rather than raising,
        # so fail here with the offending path instead of crashing later
        # inside the transform with an opaque error.
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"Image file not found: {file_path}")
        image = cv2.imread(file_path)
        if image is None:
            raise ValueError(f"Could not decode image: {file_path}")
        # OpenCV loads BGR; convert to RGB before transforming.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]
        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(torch.float),
        }
        return image, metadata
# Input width of the geo embedding head in FungiMEEModel.
GEO_SIZE: int = 7

# Default number of output classes of FungiMEEModel.
NUM_CLASSES: int = 1717
class StarReLU(nn.Module):
    """Activation computing ``scale * relu(x) ** 2 + bias``.

    ``scale`` and ``bias`` are single-element parameters; each can be
    frozen through its ``*_learnable`` flag. ``mode`` is accepted but not
    used.
    """

    def __init__(
        self,
        scale_value=1.0,
        bias_value=0.0,
        scale_learnable=True,
        bias_learnable=True,
        mode=None,
        inplace=False,
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        initial_scale = scale_value * torch.ones(1)
        initial_bias = bias_value * torch.ones(1)
        self.scale = nn.Parameter(initial_scale, requires_grad=scale_learnable)
        self.bias = nn.Parameter(initial_bias, requires_grad=bias_learnable)

    def forward(self, x):
        """Apply the squared-ReLU followed by the affine scale and bias."""
        rectified = self.relu(x)
        squared = rectified ** 2
        return self.scale * squared + self.bias
class FungiMEEModel(nn.Module):
    """Classifier fusing an image embedding with date, geo and substrate metadata.

    Each metadata group is projected to ``dim`` by its own MLP head. The
    image embedding and the three projections are stacked into a
    four-token sequence, passed through a transformer encoder, and the
    output at the image position is classified.

    Args:
        num_classes: Number of output logits.
        dim: Shared token width; the image embedding passed to ``forward``
            must already have this size for ``torch.stack`` to succeed.

    The module does not move itself to ``self.device``; callers do that,
    while ``forward`` moves its inputs there.
    """

    def __init__(
        self,
        num_classes=NUM_CLASSES,
        dim=1024,
    ):
        super().__init__()
        print("Setting up Pytorch Model")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using devide: {self.device}")
        # Input widths follow the column lists that the datasets slice the
        # metadata with, so the heads always match the tensors they receive.
        # (The undefined DATE_SIZE name raised NameError here before.)
        self.date_embedding = MlpHead(
            dim=len(TIME), num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.geo_embedding = MlpHead(
            dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        # MlpHead requires `dim`; omitting it raised TypeError.
        self.substr_embedding = MlpHead(
            dim=len(SUBSTRATE),
            num_classes=dim,
            mlp_ratio=8,
            act_layer=StarReLU,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=4,
        )
        self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)

        # Re-initialise every parameter with more than one dimension.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def forward(self, img_emb, metadata):
        """Return class logits for a batch.

        Args:
            img_emb: Image embeddings, one row per sample.
            metadata: Dict with ``"date"``, ``"geo"`` and ``"substr"`` tensors.
        """
        img_emb = img_emb.to(self.device)
        date_emb = self.date_embedding(metadata["date"].to(self.device))
        geo_emb = self.geo_embedding(metadata["geo"].to(self.device))
        substr_emb = self.substr_embedding(metadata["substr"].to(self.device))

        # Token order matters: position 0 (the image) is the one read out.
        full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)
        cls_emb = self.encoder(full_emb)[:, 0, :].squeeze(1)
        return self.head(cls_emb)

    def predict(self, img_emb, metadata):
        """Return the arg-max class index of each sample as a list of ints."""
        # Gradients are never needed for hard predictions.
        with torch.no_grad():
            logits = self.forward(img_emb, metadata)
        return logits.argmax(1).tolist()
class FungiEnsembleModel(nn.Module):
    """Average the outputs of several models that share one call signature.

    Args:
        models: Iterable of modules called as ``forward(img_emb, metadata)``.
            Each is moved to the ensemble device and put in eval mode.
        softmax: When true, every member's output is soft-maxed over
            dimension 1 before averaging; otherwise raw outputs are averaged.
    """

    def __init__(self, models, softmax=True) -> None:
        super().__init__()
        self.models = nn.ModuleList()
        self.softmax = softmax
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        for member in models:
            member = member.to(self.device)
            member.eval()
            self.models.append(member)

    def forward(self, img_emb, metadata):
        """Return the mean member output as a detached CPU tensor."""
        img_emb = img_emb.to(self.device)
        member_outputs = []
        for member in self.models:
            scores = member.forward(img_emb, metadata)
            if self.softmax:
                scores = scores.softmax(dim=1)
            member_outputs.append(scores.detach().cpu())
        stacked = torch.stack(member_outputs)
        return stacked.mean(dim=0)

    def predict(self, img_emb, metadata):
        """Return the arg-max index of the averaged output for each sample."""
        averaged = self.forward(img_emb, metadata)
        return averaged.argmax(1).tolist()
def is_gpu_available():
    """Return True when PyTorch reports a usable CUDA device."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
class PytorchWorker:
    """Hold a timm backbone plus preprocessing for single-image forward passes.

    Args:
        model_path: Accepted for interface compatibility.
            NOTE(review): no weights are loaded from it and the model is
            created with ``pretrained=False`` - confirm this is intended.
        model_name: Architecture name handed to ``timm.create_model``.
        number_of_categories: Accepted but unused; the model is built with
            ``num_classes=0``.
    """

    def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1605):
        def _load_model(model_name, model_path):
            print("Setting up Pytorch Model")
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            print(f"Using devide: {self.device}")
            model = timm.create_model(model_name, num_classes=0, pretrained=False)
            return model.to(self.device).eval()

        self.model = _load_model(model_name, model_path)
        self.transforms = T.Compose(
            [
                T.Resize((518, 518)),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

    def predict_image(self, image: np.ndarray):
        """Run one forward pass on ``image`` and return a placeholder result.

        :param image: Input image as a numpy array (a PIL image is passed
            through unchanged).
        :return: Always ``[-1]``; the model output is discarded.
        """
        # torchvision's Resize accepts PIL images or tensors, not ndarrays,
        # so convert the documented input type first.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            self.model(self.transforms(image).unsqueeze(0).to(self.device))
        return [-1]
def make_submission(metadata_df, model_names=None):
    """Predict with an ensemble of checkpoints and write ``./submission.csv``.

    Args:
        metadata_df: Frame with ``observation_id``, an ``embedding`` column
            and the metadata columns read by ``EmbeddingMetadataDataset``.
        model_names: Checkpoint file names inside ``./checkpoints``; when
            empty or None, every entry of that directory is used.

    The per-image ensemble outputs are averaged per ``observation_id``. The
    arg-max index becomes ``class_id``, except that indices above 1603 are
    written as -1 (presumably the "unknown" label - confirm against the
    task definition).
    """
    OUTPUT_CSV_PATH = "./submission.csv"
    BASE_CKPT_PATH = "./checkpoints"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_names = model_names or os.listdir(BASE_CKPT_PATH)
    models = []
    for model_path in model_names:
        print("loading ", model_path)
        ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
        # map_location lets GPU-saved checkpoints load on CPU-only hosts.
        ckpt = torch.load(ckpt_path, map_location=device)
        model = FungiMEEModel()
        # Checkpoint keys carry a "model." prefix that the bare module lacks.
        weights = ckpt["state_dict"]
        model.load_state_dict(
            {name: weights["model." + name] for name in model.state_dict()}
        )
        model.eval()
        # .to(device) instead of .cuda() so a CPU-only host does not crash.
        model.to(device)
        models.append(model)

    ensemble_model = FungiEnsembleModel(models)
    embedding_dataset = EmbeddingMetadataDataset(metadata_df)
    # shuffle=False keeps predictions aligned with the rows of metadata_df.
    loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)

    preds = []
    with torch.no_grad():
        for emb, metadata in tqdm(loader):
            preds.append(ensemble_model(emb, metadata))
    all_preds = torch.vstack(preds).numpy()

    # One numeric column per class makes the group-wise mean well defined;
    # averaging an object column of arrays depends on the pandas version.
    probs_df = pd.DataFrame(all_preds)
    probs_df.insert(0, "observation_id", metadata_df["observation_id"].to_numpy())
    mean_probs = probs_df.groupby("observation_id").mean()

    best = mean_probs.to_numpy().argmax(axis=1)
    submission = pd.DataFrame(
        {
            "observation_id": mean_probs.index.to_numpy(),
            "class_id": np.where(best <= 1603, best, -1),
        }
    )
    submission.to_csv(OUTPUT_CSV_PATH, index=False)
    print("Submission complete")
if __name__ == "__main__":
    # Script entry point: unpack the test archive, embed every image, then
    # write the submission file.
    import zipfile

    MODEL_PATH = "metaformer-s-224.pth"
    MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"

    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as archive:
        archive.extractall("/tmp/data")

    root_dir = "/tmp/data"
    metadata_file_path = "./_test_preprocessed.csv"

    metadata_df = generate_embeddings(metadata_file_path, root_dir)
    make_submission(metadata_df)