# DSGT_FungiClef / script.py
# Last change by chychiu: "Update script.py" (commit 2e47c02, verified)
import os
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from PIL import Image
from timm.models.metaformer import MlpHead
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
import cv2
# Side length (pixels) that ImageDataset resizes every image to.
DIM = 518
# Width of the classifier head's output.
NUM_CLASSES = 1717

# Metadata column groups read from the preprocessed CSV.
TIME = ["m0", "m1", "d0", "d1"]
GEO = ["g0", "g1", "g2", "g3", "g4", "g5", "g_float"]
SUBSTRATE = (
    [f"substrate_{i}" for i in range(31)]
    + [f"metasubstrate_{i}" for i in range(10)]
    + [f"habitat_{i}" for i in range(32)]
)

# Feature-vector widths, kept in lockstep with the column lists above.
DATE_SIZE = len(TIME)
GEO_SIZE = len(GEO)
SUBSTRATE_SIZE = len(SUBSTRATE)
class ImageDataset(Dataset):
    """Load the images listed in a DataFrame and return normalized tensors.

    Args:
        df: DataFrame with an ``image_path`` column; each entry is joined onto
            ``local_filepath`` to locate the file.
        local_filepath: Root directory the ``image_path`` entries live under.
    """

    def __init__(self, df, local_filepath):
        self.df = df
        # Resize to DIM x DIM, normalize each channel with mean = std = 0.5
        # (with albumentations' default max_pixel_value=255 this maps pixels
        # to [-1, 1]), then convert the HWC ndarray to a CHW tensor.
        self.transform = Compose(
            [
                Resize(DIM, DIM),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2(),
            ]
        )
        self.local_filepath = local_filepath
        self.filepaths = df["image_path"].to_list()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_path = os.path.join(self.local_filepath, self.filepaths[idx])
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.transform(image=image)["image"]
class EmbeddingMetadataDataset(Dataset):
    """Pair precomputed image embeddings with their metadata feature groups.

    Rows are addressed positionally (``0 .. len(df) - 1``), which is how a
    DataLoader samples indices, independent of the DataFrame's index labels.

    Args:
        df: DataFrame with an ``embedding`` column (one array-like per row)
            plus every column named in ``TIME``, ``GEO`` and ``SUBSTRATE``.
    """

    def __init__(self, df):
        self.df = df
        # Store embeddings positionally. Indexing a pandas Series with
        # ``[idx]`` looks up *labels*, which raises KeyError or returns the
        # wrong row as soon as ``df`` has a non-default index (e.g. after
        # filtering), while the metadata arrays below are positional.
        self.emb = df["embedding"].to_list()
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(embedding, metadata)`` for row ``idx`` as float32 tensors.

        ``metadata`` has keys ``"date"``, ``"geo"`` and ``"substr"``.
        """
        # torch.tensor copies, so the stored arrays are never aliased.
        embedding = torch.tensor(np.asarray(self.emb[idx]), dtype=torch.float)
        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }
        return embedding, metadata
def generate_embeddings(metadata_file_path, root_dir):
    """Compute DINOv2 image embeddings for every row of a metadata CSV.

    Args:
        metadata_file_path: CSV with (at least) an ``image_path`` column.
        root_dir: Directory the ``image_path`` entries are relative to.

    Returns:
        The CSV as a DataFrame with an added ``embedding`` column: row ``i``
        holds row ``i`` of the stacked backbone output (assumes the backbone
        returns a ``(batch, features)`` tensor — TODO confirm), in file order.
    """
    DINOV2_CKPT = "./checkpoints/dinov2.bin"
    metadata_df = pd.read_csv(metadata_file_path)
    test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
    loader = DataLoader(test_dataset, batch_size=3, shuffle=False)

    device = torch.device("cpu")
    model = timm.create_model(
        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
    )
    # map_location keeps loading working on CPU-only hosts even if the
    # checkpoint was saved from GPU tensors.
    weights = torch.load(DINOV2_CKPT, map_location=device)
    model.load_state_dict(weights)
    model = model.to(device)
    model.eval()

    all_embs = []
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for img in tqdm(loader):
            emb = model(img.to(device))
            all_embs.append(emb.cpu().numpy())

    # np.vstack rejects an empty list, so an empty CSV gets an empty column.
    metadata_df["embedding"] = list(np.vstack(all_embs)) if all_embs else []
    return metadata_df
class StarReLU(nn.Module):
    """Squared ReLU with a scalar affine on top: ``s * relu(x) ** 2 + b``.

    Args:
        scale_value: Initial value of the scalar scale ``s``.
        bias_value: Initial value of the scalar bias ``b``.
        scale_learnable: Whether ``s`` receives gradients.
        bias_learnable: Whether ``b`` receives gradients.
        mode: Accepted but unused.
        inplace: Forwarded to the inner ``nn.ReLU`` (which then overwrites
            its input).
    """

    def __init__(
        self,
        scale_value=1.0,
        bias_value=0.0,
        scale_learnable=True,
        bias_learnable=True,
        mode=None,
        inplace=False,
    ):
        super().__init__()

        def one_element_param(initial, trainable):
            # A shape-(1,) parameter broadcasts against any activation shape.
            return nn.Parameter(initial * torch.ones(1), requires_grad=trainable)

        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = one_element_param(scale_value, scale_learnable)
        self.bias = one_element_param(bias_value, bias_learnable)

    def forward(self, x):
        rectified = self.relu(x)
        squared = rectified ** 2
        return self.scale * squared + self.bias
class FungiMEEModel(nn.Module):
    """Classifier fusing an image embedding with date, geo and substrate metadata.

    Each metadata vector is lifted to ``dim`` by its own ``MlpHead``. The image
    embedding and the three lifted vectors are stacked into a 4-token
    sequence and mixed by a 4-layer Transformer encoder. The output at token 0
    (the image token) is classified by a final ``MlpHead``.

    Args:
        num_classes: Width of the output logits.
        dim: Token width. It must equal the image-embedding size (the tokens
            are stacked) and be divisible by the encoder's 8 attention heads.
    """

    def __init__(
        self,
        num_classes=NUM_CLASSES,
        dim=1024,
    ):
        super().__init__()
        print("Setting up Pytorch Model")
        # forward() moves every input onto this device.
        self.device = torch.device("cpu")
        print(f"Using device: {self.device}")
        self.date_embedding = MlpHead(
            dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.geo_embedding = MlpHead(
            dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.substr_embedding = MlpHead(
            dim=SUBSTRATE_SIZE,
            num_classes=dim,
            mlp_ratio=8,
            act_layer=StarReLU,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=4,
        )
        self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)
        # Kaiming-initialise every weight matrix. 1-D parameters (biases,
        # norm scales, StarReLU scale/bias) keep their default init.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def forward(self, img_emb, metadata):
        """Return class logits for a batch.

        Args:
            img_emb: Image embeddings, presumably ``(batch, dim)`` — confirm.
            metadata: Mapping with ``"date"``, ``"geo"`` and ``"substr"``
                feature tensors for the same batch.
        """
        img_emb = img_emb.to(self.device)
        date_emb = self.date_embedding(metadata["date"].to(self.device))
        geo_emb = self.geo_embedding(metadata["geo"].to(self.device))
        substr_emb = self.substr_embedding(metadata["substr"].to(self.device))
        # Stack along a new token axis, image token first, so it can be read
        # back at position 0 after the encoder.
        full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)
        # Plain indexing already drops the token axis; the former
        # ``.squeeze(1)`` was a no-op (or wrong when dim == 1).
        cls_emb = self.encoder(full_emb)[:, 0, :]
        return self.head(cls_emb)

    @torch.no_grad()
    def predict(self, img_emb, metadata):
        """Return the argmax class index for each batch row as a list of ints."""
        logits = self.forward(img_emb, metadata)
        return logits.argmax(1).tolist()
class FungiEnsembleModel(nn.Module):
    """Soft-voting ensemble that averages its members' softmax probabilities.

    Args:
        models: Iterable of modules called as ``model(img_emb, metadata)`` and
            returning ``(batch, num_classes)`` logits. Each is moved to CPU and
            put in eval mode.
    """

    def __init__(self, models) -> None:
        super().__init__()
        self.models = nn.ModuleList()
        self.device = torch.device("cpu")
        for model in models:
            model = model.to(self.device)
            model.eval()
            self.models.append(model)

    @torch.no_grad()
    def forward(self, img_emb, metadata):
        """Return the member-averaged class probabilities, on CPU.

        Runs under ``no_grad``: the outputs were always detached, so building
        autograd graphs for every member only wasted memory.

        Raises:
            ValueError: if the ensemble has no member models.
        """
        if len(self.models) == 0:
            raise ValueError("FungiEnsembleModel has no member models")
        img_emb = img_emb.to(self.device)
        probs = [
            model(img_emb, metadata).softmax(dim=1).cpu() for model in self.models
        ]
        return torch.stack(probs).mean(dim=0)

    def predict(self, img_emb, metadata):
        """Return the most probable class index for each batch row."""
        probs = self.forward(img_emb, metadata)
        return probs.argmax(1).tolist()
def _load_ensemble(base_ckpt_path, model_names):
    """Load each FungiMEEModel checkpoint and wrap them in a soft-voting ensemble."""
    models = []
    for model_name in model_names:
        print("loading ", model_name)
        ckpt_path = os.path.join(base_ckpt_path, model_name)
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model = FungiMEEModel()
        # Checkpoint keys carry a "model." prefix; look each of this
        # module's state_dict keys up under that prefix.
        model.load_state_dict(
            {key: ckpt["model." + key] for key in model.state_dict()}
        )
        model.eval()
        models.append(model)
    return FungiEnsembleModel(models)


def _predict_probabilities(fungi_model, metadata_df, batch_size=128):
    """Return a numpy array of ensemble probabilities, one row per DataFrame row."""
    loader = DataLoader(
        EmbeddingMetadataDataset(metadata_df), batch_size=batch_size, shuffle=False
    )
    batches = []
    with torch.no_grad():
        for emb, metadata in tqdm(loader):
            batches.append(fungi_model(emb, metadata))
    return torch.vstack(batches).numpy()


def _aggregate_by_observation(observation_ids, probs, max_class_id):
    """Average image-level probabilities per observation and pick one class each.

    Returns a DataFrame with ``observation_id`` (ascending, as groupby sorts
    keys) and ``class_id``. Argmax indices above ``max_class_id`` become -1.
    """
    # Aggregate on a plain numeric frame instead of an object column of
    # arrays, which pandas' groupby-mean handles unreliably across versions.
    mean_probs = pd.DataFrame(probs, index=observation_ids).groupby(level=0).mean()
    best = mean_probs.to_numpy().argmax(axis=1)
    return pd.DataFrame(
        {
            "observation_id": mean_probs.index.to_numpy(),
            "class_id": np.where(best <= max_class_id, best, -1),
        }
    )


def make_submission(metadata_df):
    """Run the checkpoint ensemble over precomputed embeddings and write the CSV.

    Args:
        metadata_df: DataFrame with an ``embedding`` column, an
            ``observation_id`` column and the metadata feature columns read by
            ``EmbeddingMetadataDataset``.

    Writes ``./submission.csv`` with one ``observation_id,class_id`` row per
    observation; image-level probabilities are averaged within an observation.
    """
    OUTPUT_CSV_PATH = "./submission.csv"
    BASE_CKPT_PATH = "./checkpoints"
    # Argmax class ids above this value are submitted as -1.
    # NOTE(review): presumably ids 1604 .. NUM_CLASSES - 1 are non-target /
    # "unknown" classes in the training label space — confirm.
    MAX_SUBMITTED_CLASS_ID = 1603
    model_names = [
        "dino_2_optuna_05242231.ckpt",
        "dino_optuna_05241449.ckpt",
        "dino_optuna_05241257.ckpt",
        "dino_optuna_05241222.ckpt",
        "dino_2_optuna_05242055.ckpt",
        "dino_2_optuna_05242156.ckpt",
        "dino_2_optuna_05242344.ckpt",
    ]

    if metadata_df.empty:
        # Nothing to predict: emit a header-only submission.
        submission = pd.DataFrame(columns=["observation_id", "class_id"])
    else:
        fungi_model = _load_ensemble(BASE_CKPT_PATH, model_names)
        probs = _predict_probabilities(fungi_model, metadata_df)
        submission = _aggregate_by_observation(
            metadata_df["observation_id"].to_numpy(), probs, MAX_SUBMITTED_CLASS_ID
        )

    submission.to_csv(OUTPUT_CSV_PATH, index=False)
    print("Submission complete")
if __name__ == "__main__":
    import zipfile

    # Real-submission inputs. For a local trial run, use instead:
    #     metadata_file_path = "../trial_submission.csv"
    #     root_dir = "../data/DF_FULL"
    metadata_file_path = "./_test_preprocessed.csv"
    root_dir = "/tmp/data/private_testset"

    # The private test images ship as a zip; unpack it under /tmp/data/
    # (presumably producing the ``root_dir`` folder — confirm archive layout).
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as archive:
        archive.extractall("/tmp/data/")

    metadata_df = generate_embeddings(metadata_file_path, root_dir)
    make_submission(metadata_df)