import os

import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from timm.models.metaformer import MlpHead
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

DIM = 518  # input resolution for the DINOv2 ViT-L/14 backbone
DATE_SIZE = 4
GEO_SIZE = 7
SUBSTRATE_SIZE = 73
NUM_CLASSES = 1717

TIME = ["m0", "m1", "d0", "d1"]
GEO = ["g0", "g1", "g2", "g3", "g4", "g5", "g_float"]
# 31 substrate + 10 metasubstrate + 32 habitat one-hot columns = 73 features.
SUBSTRATE = (
    [f"substrate_{i}" for i in range(31)]
    + [f"metasubstrate_{i}" for i in range(10)]
    + [f"habitat_{i}" for i in range(32)]
)


class ImageDataset(Dataset):
    """Reads test images from disk and applies the DINOv2 eval transform."""

    def __init__(self, df, local_filepath):
        self.df = df
        self.transform = Compose(
            [
                Resize(DIM, DIM),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2(),
            ]
        )
        self.local_filepath = local_filepath
        self.filepaths = df["image_path"].to_list()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_path = os.path.join(self.local_filepath, self.filepaths[idx])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        return self.transform(image=image)["image"]


class EmbeddingMetadataDataset(Dataset):
    """Pairs precomputed image embeddings with their tabular metadata."""

    def __init__(self, df):
        self.df = df
        self.emb = df["embedding"]
        self.metadata_date = df[TIME].to_numpy()
        self.metadata_geo = df[GEO].to_numpy()
        self.metadata_substrate = df[SUBSTRATE].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        embedding = torch.Tensor(self.emb[idx].copy()).type(torch.float)
        metadata = {
            "date": torch.from_numpy(self.metadata_date[idx, :]).type(torch.float),
            "geo": torch.from_numpy(self.metadata_geo[idx, :]).type(torch.float),
            "substr": torch.from_numpy(self.metadata_substrate[idx, :]).type(
                torch.float
            ),
        }
        return embedding, metadata


def generate_embeddings(metadata_file_path, root_dir):
    """Embed every test image with DINOv2 and attach the vectors to the metadata."""
    DINOV2_CKPT = "./checkpoints/dinov2.bin"

    metadata_df = pd.read_csv(metadata_file_path)
    test_dataset = ImageDataset(metadata_df, local_filepath=root_dir)
    loader = DataLoader(test_dataset, batch_size=3, shuffle=False)

    device = torch.device("cpu")
    model = timm.create_model(
        "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
    )
    # map_location guards against checkpoints that were saved on GPU.
    weights = torch.load(DINOV2_CKPT, map_location=device)
    model.load_state_dict(weights)
    model = model.to(device)
    model.eval()

    all_embs = []
    with torch.no_grad():
        for img in tqdm(loader):
            img = img.to(device)
            emb = model.forward(img)
            all_embs.append(emb.detach().cpu().numpy())

    all_embs = np.vstack(all_embs)
    embs_list = [x for x in all_embs]
    metadata_df["embedding"] = embs_list
    return metadata_df
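
# Usage sketch (illustrative only; the file names below are hypothetical). The
# CSV must provide an "image_path" column for this step, plus the TIME/GEO/
# SUBSTRATE columns consumed later by EmbeddingMetadataDataset. With this
# ViT-L/14 backbone, each row gains a 1024-dim "embedding" vector:
#
#   df = generate_embeddings("./metadata.csv", "./images")
#   print(df["embedding"].iloc[0].shape)  # expected: (1024,)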

class StarReLU(nn.Module):
    """StarReLU: s * relu(x) ** 2 + b"""

    def __init__(
        self,
        scale_value=1.0,
        bias_value=0.0,
        scale_learnable=True,
        bias_learnable=True,
        mode=None,
        inplace=False,
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(
            scale_value * torch.ones(1), requires_grad=scale_learnable
        )
        self.bias = nn.Parameter(
            bias_value * torch.ones(1), requires_grad=bias_learnable
        )

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class FungiMEEModel(nn.Module):
    """Fuses the image embedding with date/geo/substrate embeddings via a small
    Transformer encoder and classifies the fused representation."""

    def __init__(
        self,
        num_classes=NUM_CLASSES,
        dim=1024,
    ):
        super().__init__()
        print("Setting up PyTorch model")
        self.device = torch.device("cpu")
        print(f"Using device: {self.device}")

        # Project each metadata modality up to the image-embedding width.
        self.date_embedding = MlpHead(
            dim=DATE_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.geo_embedding = MlpHead(
            dim=GEO_SIZE, num_classes=dim, mlp_ratio=128, act_layer=StarReLU
        )
        self.substr_embedding = MlpHead(
            dim=SUBSTRATE_SIZE, num_classes=dim, mlp_ratio=8, act_layer=StarReLU
        )

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=4,
        )

        self.head = MlpHead(dim=dim, num_classes=num_classes, drop_rate=0)

        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def forward(self, img_emb, metadata):
        img_emb = img_emb.to(self.device)

        date_emb = self.date_embedding.forward(metadata["date"].to(self.device))
        geo_emb = self.geo_embedding.forward(metadata["geo"].to(self.device))
        substr_emb = self.substr_embedding.forward(metadata["substr"].to(self.device))

        # Treat the four embeddings as a length-4 token sequence; the image
        # token's output serves as the classification representation.
        full_emb = torch.stack((img_emb, date_emb, geo_emb, substr_emb), dim=1)
        cls_emb = self.encoder.forward(full_emb)[:, 0, :]
        return self.head.forward(cls_emb)

    def predict(self, img_emb, metadata):
        logits = self.forward(img_emb, metadata)
        return logits.argmax(1).tolist()


class FungiEnsembleModel(nn.Module):
    """Averages the softmax probabilities of several FungiMEEModel members."""

    def __init__(self, models) -> None:
        super().__init__()
        self.models = nn.ModuleList()
        self.device = torch.device("cpu")

        for model in models:
            model = model.to(self.device)
            model.eval()
            self.models.append(model)

    def forward(self, img_emb, metadata):
        img_emb = img_emb.to(self.device)

        probs = []
        for model in self.models:
            logits = model.forward(img_emb, metadata)
            probs.append(logits.softmax(dim=1).detach().cpu())

        return torch.stack(probs).mean(dim=0)

    def predict(self, img_emb, metadata):
        probs = self.forward(img_emb, metadata)
        return probs.argmax(1).tolist()


def make_submission(metadata_df):
    OUTPUT_CSV_PATH = "./submission.csv"
    BASE_CKPT_PATH = "./checkpoints"

    model_names = [
        "dino_2_optuna_05242231.ckpt",
        "dino_optuna_05241449.ckpt",
        "dino_optuna_05241257.ckpt",
        "dino_optuna_05241222.ckpt",
        "dino_2_optuna_05242055.ckpt",
        "dino_2_optuna_05242156.ckpt",
        "dino_2_optuna_05242344.ckpt",
    ]

    models = []
    for model_path in model_names:
        print("loading", model_path)
        ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model = FungiMEEModel()
        # Checkpoint keys carry a "model." prefix; strip it to match this module.
        model.load_state_dict(
            {w: ckpt["model." + w] for w in model.state_dict().keys()}
        )
        model.eval()
        models.append(model)

    fungi_model = FungiEnsembleModel(models)

    embedding_dataset = EmbeddingMetadataDataset(metadata_df)
    loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)

    preds = []
    with torch.no_grad():
        for emb, metadata in tqdm(loader):
            preds.append(fungi_model.forward(emb, metadata))

    all_preds = torch.vstack(preds).numpy()

    preds_df = metadata_df[["observation_id", "image_path"]].copy()
    preds_df["preds"] = [i for i in all_preds]
    # Average the per-image probability vectors within each observation.
    preds_df = (
        preds_df[["observation_id", "preds"]]
        .groupby("observation_id")
        .mean()
        .reset_index()
    )
    # Predictions with argmax above index 1603 are mapped to -1 ("unknown").
    preds_df["class_id"] = preds_df["preds"].apply(
        lambda x: x.argmax() if x.argmax() <= 1603 else -1
    )
    preds_df[["observation_id", "class_id"]].to_csv(OUTPUT_CSV_PATH, index=None)
    print("Submission complete")
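
# A minimal smoke-test sketch (hypothetical; not called anywhere in the real
# pipeline) showing the contract between FungiMEEModel and FungiEnsembleModel:
# each member maps (img_emb, metadata) -> logits of shape (B, NUM_CLASSES), and
# the ensemble returns softmax-averaged probabilities of the same shape.
def _ensemble_smoke_test(batch_size=2):
    model = FungiMEEModel()
    img_emb = torch.randn(batch_size, 1024)  # stand-in for a DINOv2 image embedding
    metadata = {
        "date": torch.randn(batch_size, DATE_SIZE),
        "geo": torch.randn(batch_size, GEO_SIZE),
        "substr": torch.randn(batch_size, SUBSTRATE_SIZE),
    }
    with torch.no_grad():
        probs = FungiEnsembleModel([model]).forward(img_emb, metadata)
    assert probs.shape == (batch_size, NUM_CLASSES)
    # Each row is a probability distribution, so it should sum to ~1.
    assert torch.allclose(probs.sum(dim=1), torch.ones(batch_size), atol=1e-5)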

if __name__ == "__main__":
    import zipfile

    # Real submission
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
        zip_ref.extractall("/tmp/data/")

    metadata_file_path = "./_test_preprocessed.csv"
    root_dir = "/tmp/data/private_testset"

    # Test submission
    # metadata_file_path = "../trial_submission.csv"
    # root_dir = "../data/DF_FULL"

    metadata_df = generate_embeddings(metadata_file_path, root_dir)
    make_submission(metadata_df)
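
# Expected shape of ./submission.csv (the observation ids below are made up):
# one row per observation, with class_id = -1 whenever the averaged argmax
# falls outside the known-class range (> 1603), as handled in make_submission:
#
#   observation_id,class_id
#   10001,311
#   10002,-1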