DSGT_FungiClef / script.py
chychiu's picture
fixed script
8f1fb11
raw
history blame
13.2 kB
import os
from typing import List
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from albumentations import (
CenterCrop,
Compose,
HorizontalFlip,
Normalize,
PadIfNeeded,
RandomBrightnessContrast,
RandomCrop,
RandomResizedCrop,
Resize,
VerticalFlip,
)
from albumentations.pytorch import ToTensorV2
from PIL import Image
from timm.layers import LayerNorm2d, SelectAdaptivePool2d
from timm.models.metaformer import MlpHead
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Default side length (in pixels) used by get_transforms when no explicit
# width/height is supplied, and passed explicitly by generate_embeddings.
DIM = 518
def get_transforms(*, data, model=None, width=None, height=None):
    """Build the albumentations pipeline for one dataset split.

    Args:
        data: ``"train"`` for the augmenting pipeline or ``"valid"`` for the
            deterministic resize-only pipeline.
        model: Optional model exposing ``default_cfg["mean"]`` and
            ``default_cfg["std"]``. When omitted, 0.5 is used for every
            channel of both mean and std.
        width: Target width in pixels; falls back to ``DIM`` when falsy.
        height: Target height in pixels; falls back to ``DIM`` when falsy.

    Returns:
        A ``Compose`` that ends with ``Normalize`` and ``ToTensorV2``.

    Raises:
        ValueError: If ``data`` is neither ``"train"`` nor ``"valid"``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if data not in ("train", "valid"):
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    width = width or DIM
    height = height or DIM

    if model is not None:
        mean = list(model.default_cfg["mean"])
        std = list(model.default_cfg["std"])
    else:
        mean = std = (0.5, 0.5, 0.5)

    # Both splits end with the same normalisation + tensor conversion.
    tail = [Normalize(mean=mean, std=std), ToTensorV2()]

    # height/width are passed by keyword: the classic albumentations
    # signatures take (height, width), so the previous positional
    # (width, height) call transposed non-square targets.
    if data == "train":
        head = [
            RandomResizedCrop(height=height, width=width, scale=(0.6, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2),
        ]
    else:
        head = [Resize(height=height, width=width)]

    return Compose(head + tail)
def generate_embeddings(metadata_file_path, root_dir):
    """Compute one image embedding per metadata row.

    Reads the metadata CSV, loads every referenced image relative to
    ``root_dir`` through ``ImageMetadataDataset`` with the ``"valid"``
    transforms, and runs the pretrained
    ``timm/vit_large_patch14_reg4_dinov2.lvd142m`` model on each image.

    Args:
        metadata_file_path: Path of the CSV file to read.
        root_dir: Directory that the ``image_path`` column is relative to.

    Returns:
        The loaded DataFrame with an added ``"embedding"`` column holding one
        numpy array per row.
    """
    metadata_df = pd.read_csv(metadata_file_path)

    # NOTE(review): no model is passed here, so normalisation uses mean/std
    # of 0.5 rather than the backbone's default_cfg values - confirm this
    # matches how the downstream checkpoints were trained before changing it.
    transforms = get_transforms(data="valid", width=DIM, height=DIM)
    test_dataset = ImageMetadataDataset(
        metadata_df, local_filepath=root_dir, transform=transforms
    )
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = timm.create_model(
        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=True
    )
    model = model.to(device)
    model.eval()

    all_embs = []
    # Pure inference: disabling autograd avoids building a graph per image.
    with torch.no_grad():
        for img, _ in tqdm(loader):
            emb = model(img.to(device))
            all_embs.append(emb.cpu().numpy())

    # np.vstack rejects an empty list, so an empty metadata file is handled
    # explicitly and simply yields an empty column.
    embs_list = list(np.vstack(all_embs)) if all_embs else []
    metadata_df["embedding"] = embs_list
    return metadata_df
# Metadata column groups that the dataset classes select from the dataframe.
TIME = ["m0", "m1", "d0", "d1"]

# g0 .. g5 followed by the single g_float column.
GEO = [f"g{i}" for i in range(6)] + ["g_float"]

# substrate_0..30, then metasubstrate_0..9, then habitat_0..31 (73 columns,
# in exactly this order).
SUBSTRATE = [
    f"{prefix}_{i}"
    for prefix, count in (("substrate", 31), ("metasubstrate", 10), ("habitat", 32))
    for i in range(count)
]
class EmbeddingMetadataDataset(Dataset):
    """Dataset yielding a precomputed embedding and its metadata tensors.

    Args:
        df: DataFrame with an ``"embedding"`` column plus the ``TIME``,
            ``GEO`` and ``SUBSTRATE`` columns.

    Each item is ``(embedding, metadata)`` where ``embedding`` is a float
    tensor and ``metadata`` is a dict with float tensors under the keys
    ``"date"``, ``"geo"`` and ``"substr"``.
    """

    def __init__(self, df):
        self.df = df
        self.emb = df["embedding"]
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup (.iloc): the metadata arrays below are indexed by
        # position, so the embedding must be too. Plain ``self.emb[idx]`` is
        # label-based and breaks when the DataFrame index is not 0..n-1
        # (for example after filtering rows).
        # torch.tensor copies the data, so the stored embedding is never
        # aliased by the returned tensor.
        embedding = torch.tensor(self.emb.iloc[idx], dtype=torch.float)
        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }
        return embedding, metadata
class ImageMetadataDataset(Dataset):
    """Dataset yielding an RGB image and its metadata tensors.

    Args:
        df: DataFrame with an ``"image_path"`` column plus the ``TIME``,
            ``GEO`` and ``SUBSTRATE`` columns.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``["image"]``.
        local_filepath: Directory that ``image_path`` values are joined onto;
            ``None`` means the paths are used as they are.

    Each item is ``(image, metadata)`` where ``metadata`` is a dict with float
    tensors under the keys ``"date"``, ``"geo"`` and ``"substr"``.
    """

    def __init__(self, df, transform=None, local_filepath=None):
        self.df = df
        self.transform = transform
        self.local_filepath = local_filepath
        self.filepaths = df["image_path"].to_list()
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Joining onto "" leaves the path unchanged, so the documented default
        # of local_filepath=None no longer crashes os.path.join.
        file_path = os.path.join(self.local_filepath or "", self.filepaths[idx])

        # cv2.imread signals an unreadable or missing file by returning None.
        # Fail right here with the offending path instead of printing it and
        # letting a later step crash on a None image.
        image = cv2.imread(file_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)["image"]

        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }
        return image, metadata
# Input widths of the three metadata branches of FungiMEEModel. They equal the
# lengths of the TIME (4), GEO (7) and SUBSTRATE (73) column lists.
DATE_SIZE = 4
GEO_SIZE = 7
SUBSTRATE_SIZE = 73
# Default number of output classes of FungiMEEModel's classification head.
NUM_CLASSES = 1717
class StarReLU(nn.Module):
    """Squared ReLU with a scalar scale and bias.

    Evaluates ``scale * relu(x) ** 2 + bias`` where ``scale`` and ``bias`` are
    one-element parameters.

    Args:
        scale_value: Initial value of ``scale``.
        bias_value: Initial value of ``bias``.
        scale_learnable: Whether ``scale`` requires gradients.
        bias_learnable: Whether ``bias`` requires gradients.
        mode: Accepted but not used by this implementation.
        inplace: Forwarded to the underlying ``nn.ReLU``.
    """

    def __init__(
        self,
        scale_value=1.0,
        bias_value=0.0,
        scale_learnable=True,
        bias_learnable=True,
        mode=None,
        inplace=False,
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        # Registration order (scale, then bias) is kept so parameter and
        # state_dict ordering stay the same.
        self.scale = self._scalar_parameter(scale_value, scale_learnable)
        self.bias = self._scalar_parameter(bias_value, bias_learnable)

    @staticmethod
    def _scalar_parameter(value, learnable):
        """Return a one-element parameter initialised to ``value``."""
        return nn.Parameter(value * torch.ones(1), requires_grad=learnable)

    def forward(self, x):
        rectified = self.relu(x)
        squared = rectified ** 2
        return self.scale * squared + self.bias
class FungiMEEModel(nn.Module):
    """Classifier that fuses an image embedding with metadata embeddings.

    The date, geo and substrate metadata vectors are each projected to ``dim``
    by an ``MlpHead``. They are stacked with the image embedding along a new
    dimension 1, passed through a 4-layer transformer encoder, and the
    encoder output at position 0 (the image slot) is classified by another
    ``MlpHead``.

    Args:
        num_classes: Number of output logits.
        dim: Width shared by the image embedding and the metadata projections.
    """

    def __init__(
        self,
        num_classes=NUM_CLASSES,
        dim=1024,
    ):
        super().__init__()
        print("Setting up Pytorch Model")
        # Kept for backward compatibility; forward() follows the device the
        # weights are actually on.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using devide: {self.device}")
        self.date_embedding = MlpHead(
            dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.geo_embedding = MlpHead(
            dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.substr_embedding = MlpHead(
            dim=SUBSTRATE_SIZE,
            num_classes=dim,
            mlp_ratio=8,
            act_layer=StarReLU,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=4,
        )
        self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)

        # Re-initialise every weight matrix; 1-D parameters are left as built.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def _weights_device(self):
        """Return the device the module's parameters currently live on."""
        return next(self.parameters()).device

    def forward(self, img_emb, metadata):
        """Return class logits for a batch.

        Args:
            img_emb: Image embeddings - assumed to be ``(batch, dim)`` so they
                can be stacked with the metadata projections.
            metadata: Dict with ``"date"``, ``"geo"`` and ``"substr"`` tensors.
        """
        # Move inputs to wherever the weights are. ``self.device`` is only a
        # construction-time guess and is wrong when the module was never moved
        # to the GPU or was moved to another device afterwards.
        device = self._weights_device()
        img_emb = img_emb.to(device)
        date_emb = self.date_embedding(metadata["date"].to(device))
        geo_emb = self.geo_embedding(metadata["geo"].to(device))
        substr_emb = self.substr_embedding(metadata["substr"].to(device))

        full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)
        # Indexing position 0 already drops the token dimension, so the old
        # extra ``.squeeze(1)`` was a no-op and has been removed.
        cls_emb = self.encoder(full_emb)[:, 0, :]
        return self.head(cls_emb)

    def predict(self, img_emb, metadata):
        """Return the arg-max class index per sample as a Python list."""
        # Prediction never needs gradients.
        with torch.no_grad():
            logits = self.forward(img_emb, metadata)
        return logits.argmax(1).tolist()
class FungiEnsembleModel(nn.Module):
    """Average the outputs of several models over the same inputs.

    Args:
        models: Iterable of modules callable as ``model(img_emb, metadata)``.
            Each one is moved to the selected device and put in eval mode.
        softmax: When true, each model's logits are turned into probabilities
            (softmax over dim 1) before averaging; otherwise raw logits are
            averaged.
    """

    def __init__(self, models, softmax=True) -> None:
        super().__init__()
        self.models = nn.ModuleList()
        self.softmax = softmax
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        for model in models:
            model = model.to(self.device)
            model.eval()
            self.models.append(model)

    def forward(self, img_emb, metadata):
        """Return the mean per-model score tensor, on the CPU and detached."""
        img_emb = img_emb.to(self.device)
        scores = []
        # The result was always detached, so run the members without autograd
        # instead of building graphs that are thrown away immediately.
        with torch.no_grad():
            for model in self.models:
                logits = model(img_emb, metadata)
                if self.softmax:
                    logits = logits.softmax(dim=1)
                scores.append(logits.detach().cpu())
        return torch.stack(scores).mean(dim=0)

    def predict(self, img_emb, metadata):
        """Return the arg-max class index per sample as a Python list."""
        averaged = self.forward(img_emb, metadata)
        return averaged.argmax(1).tolist()
def is_gpu_available() -> bool:
    """Return whether PyTorch reports a usable CUDA device."""
    return torch.cuda.is_available()
class PytorchWorker:
    """Run a timm backbone on single images with PyTorch.

    Args:
        model_path: Path of a weights file. NOTE: it is currently not loaded
            (see ``_load_model``); the backbone keeps its initial weights.
        model_name: timm model name passed to ``timm.create_model``.
        number_of_categories: Accepted for interface compatibility; unused.
    """

    def __init__(
        self, model_path: str, model_name: str, number_of_categories: int = 1605
    ):
        self.model = self._load_model(model_name, model_path)
        self.transforms = T.Compose(
            [
                T.Resize((518, 518)),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

    def _load_model(self, model_name, model_path):
        """Create the headless backbone, select the device, return it in eval mode."""
        print("Setting up Pytorch Model")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using devide: {self.device}")
        model = timm.create_model(model_name, num_classes=0, pretrained=False)
        # NOTE(review): checkpoint loading was disabled in the original code
        # and is kept disabled here - confirm whether that is intentional.
        # weights = torch.load(model_path, map_location=self.device)
        # model.load_state_dict({w.replace("model.", ""): v for w, v in weights.items()})
        return model.to(self.device).eval()

    def predict_image(self, image: np.ndarray):
        """Run one forward pass on ``image`` and return a placeholder result.

        :param image: Input image as a numpy array (a PIL image also works).
        :return: Always ``[-1]``; the model output is discarded.
        """
        # T.Resize only accepts PIL images or tensors, so the documented numpy
        # input has to be converted first.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        batch = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            self.model(batch)
        return [-1]
def _aggregate_class_ids(observation_ids, scores, max_class_id):
    """Average per-image scores per observation and pick one class id each."""
    table = pd.DataFrame(scores, index=pd.Index(observation_ids))
    # Numeric per-column mean; groups come out sorted by observation id.
    mean_scores = table.groupby(level=0).mean()
    best = mean_scores.to_numpy().argmax(axis=1)
    # Winners above max_class_id are reported as -1.
    class_ids = np.where(best <= max_class_id, best, -1)
    return pd.DataFrame(
        {"observation_id": mean_scores.index.to_numpy(), "class_id": class_ids}
    )


def make_submission(metadata_df, model_names=None):
    """Write ``./submission.csv`` from an ensemble of saved checkpoints.

    Every checkpoint is loaded into a ``FungiMEEModel``, the models are
    combined by ``FungiEnsembleModel``, the ensemble scores of all images
    sharing an ``observation_id`` are averaged, and the arg-max class of each
    observation is written out.

    Args:
        metadata_df: DataFrame with ``"observation_id"`` plus everything
            ``EmbeddingMetadataDataset`` reads (``"embedding"`` and the
            metadata columns).
        model_names: Checkpoint file names inside ``./checkpoints``; when
            falsy, every entry of that directory is used.

    Each checkpoint must contain a ``"state_dict"`` mapping whose keys carry a
    ``"model."`` prefix.
    """
    OUTPUT_CSV_PATH = "./submission.csv"
    BASE_CKPT_PATH = "./checkpoints"
    # Arg-max results above this index are written as -1.
    # NOTE(review): presumably the last "known" class id - confirm.
    MAX_CLASS_ID = 1603

    # Load onto whatever hardware exists instead of assuming CUDA, so the
    # script also runs on CPU-only machines.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_names = model_names or os.listdir(BASE_CKPT_PATH)
    models = []
    for model_path in model_names:
        print("loading ", model_path)
        ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location=device)
        model = FungiMEEModel()
        model.load_state_dict(
            {w: ckpt["state_dict"]["model." + w] for w in model.state_dict()}
        )
        model.eval()
        model.to(device)
        models.append(model)

    ensemble_model = FungiEnsembleModel(models)
    embedding_dataset = EmbeddingMetadataDataset(metadata_df)
    loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)

    preds = [ensemble_model(emb, metadata) for emb, metadata in tqdm(loader)]
    all_preds = torch.vstack(preds).numpy()

    # Aggregate on a fresh numeric frame: no column is assigned on a slice of
    # metadata_df, and the mean does not depend on how the installed pandas
    # treats object columns holding arrays.
    preds_df = _aggregate_class_ids(
        metadata_df["observation_id"].to_numpy(), all_preds, MAX_CLASS_ID
    )
    preds_df[["observation_id", "class_id"]].to_csv(OUTPUT_CSV_PATH, index=False)
    print("Submission complete")
if __name__ == "__main__":
    # Script entry point: embed every image listed in the metadata CSV, then
    # write the ensemble submission file.
    MODEL_PATH = "metaformer-s-224.pth"
    MODEL_NAME = "timm/vit_base_patch14_reg4_dinov2.lvd142m"

    # --- Real submission (disabled) ---
    # import zipfile
    # with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
    #     zip_ref.extractall("/tmp/data")
    # metadata_file_path = "./_test_preprocessed.csv"
    # root_dir = "/tmp/data"

    # --- Test submission (active) ---
    metadata_file_path, root_dir = "../trial_submission.csv", "../data/DF_FULL"

    metadata_df = generate_embeddings(metadata_file_path, root_dir)
    make_submission(metadata_df)