"""Generate a SnakeCLEF2024 submission: extract DINOv2 features for the
private test set and classify them with a linear head loaded from a checkpoint."""

import zipfile
from argparse import ArgumentParser
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModel


class ImageDataset(Dataset):
    """Yields the DINOv2 CLS-token embedding and observation id for each metadata row."""

    def __init__(self, metadata_path, images_root_path, model_name="./dinov2"):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # the backbone is only ever used for inference here

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename

        # Force RGB so grayscale or RGBA test images do not break the processor.
        model_inputs = self.processor(
            images=Image.open(image_path).convert("RGB"), return_tensors="pt"
        )
        with torch.no_grad():
            outputs = self.model(**model_inputs)
            last_hidden_states = outputs.last_hidden_state

        # [0, 0] selects the CLS token of the single image in this mini-batch.
        return {
            "features": last_hidden_states[0, 0],
            "observation_id": row.observation_id,
        }
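
# A minimal smoke test for the dataset (illustrative only: it assumes a local
# "./dinov2" checkpoint plus a small CSV and image folder that are not part of
# this script):
#
#     dataset = ImageDataset("val_metadata.csv", "./val_images")
#     sample = dataset[0]
#     print(sample["features"].shape)  # 1-D CLS embedding, e.g. 768 dims for ViT-B/14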


class LinearClassifier(nn.Module):
    """A single linear layer over precomputed features, returning log-probabilities."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return torch.log_softmax(self.model(x), dim=1)
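
# Example forward pass (shapes are illustrative; 768 matches the ViT-B/14
# embedding size mentioned above):
#
#     clf = LinearClassifier(num_features=768, num_classes=10)
#     log_probs = clf(torch.randn(4, 768))   # shape (4, 10)
#     preds = log_probs.argmax(dim=1)        # predicted class ids, shape (4,)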


def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    # Rebuild the classifier from the checkpoint's stored hyperparameters.
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]
    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    dataloader = DataLoader(
        ImageDataset(test_metadata, images_root_path), batch_size=32
    )
    rows = []
    with torch.no_grad():
        for batch in dataloader:
            observation_ids = batch["observation_id"]
            log_probs = model(batch["features"])
            class_ids = torch.argmax(log_probs, dim=1)
            for observation_id, class_id in zip(observation_ids, class_ids):
                rows.append(
                    {"observation_id": int(observation_id), "class_id": int(class_id)}
                )

    # An observation may have several images; keep the first prediction per observation.
    submission_df = pd.DataFrame(rows).drop_duplicates("observation_id", keep="first")
    submission_df.to_csv(output_csv_path, index=False)
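
# The submission file written above has exactly two columns; the values shown
# here are made up for illustration:
#
#     observation_id,class_id
#     123456,7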


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, default="./last.ckpt")
    parser.add_argument(
        "--metadata-file-path", type=str, default="./SnakeCLEF2024-TestMetadata.csv"
    )
    return parser.parse_args()
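
# Typical invocation (the script name is illustrative; both flags fall back to
# the defaults declared in parse_args):
#
#     python make_submission.py \
#         --model-path ./last.ckpt \
#         --metadata-file-path ./SnakeCLEF2024-TestMetadata.csv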


if __name__ == "__main__":
    args = parse_args()
    # Unpack the hidden test images before the dataloader tries to read them.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
        zip_ref.extractall("/tmp/data")

    make_submission(test_metadata=args.metadata_file_path, model_path=args.model_path)