File size: 3,333 Bytes
867532a
df2ac53
867532a
d41c4d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867532a
 
 
 
 
 
 
 
 
 
df2ac53
 
867532a
df2ac53
 
 
867532a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
import zipfile
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModel


class ImageDataset(Dataset):
    """Map-style dataset yielding raw test images and their observation ids.

    Each item is a dict with:
      - "features": HxWx3 uint8 tensor of the decoded image
      - "observation_id": the metadata row's observation_id value
    """

    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        # One row per image; expected columns include `filename` and
        # `observation_id` (read in __getitem__).
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename
        # Use a context manager so the underlying file handle is closed
        # promptly (PIL opens lazily and otherwise leaks descriptors across
        # DataLoader workers), and normalize to RGB so grayscale/RGBA files
        # produce a consistent 3-channel array.
        with Image.open(image_path) as img:
            arr = np.array(img.convert("RGB"))
        return {"features": torch.from_numpy(arr), "observation_id": row.observation_id}


class LinearClassifier(nn.Module):
    """A single affine layer that maps feature vectors to per-class
    log-probabilities (log-softmax over the last dimension).

    NOTE: the attribute name ``model`` is load-bearing — checkpoint
    ``state_dict`` keys are prefixed with it, so it must not be renamed.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        logits = self.model(x)
        # log-softmax keeps downstream argmax identical to raw logits while
        # providing normalized log-probabilities.
        return torch.log_softmax(logits, dim=1)


class TransformDino:
    """Batch transform that swaps raw images for DINOv2 CLS embeddings.

    ``forward`` replaces ``batch["features"]`` (images) with the backbone's
    CLS-token hidden state; all other batch keys pass through untouched.
    """

    def __init__(self, model_name="facebook/dinov2-base"):
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, batch):
        inputs = self.processor(images=batch["features"], return_tensors="pt")
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
        # Position 0 is the CLS token — used as the image-level embedding.
        batch["features"] = hidden[:, 0]
        return batch


def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    """Run DINOv2 + linear-head inference over the test set and write a
    submission CSV with columns ``observation_id`` and ``class_id``.

    Args:
        test_metadata: path to the test metadata CSV.
        model_path: path to a checkpoint containing ``hyper_parameters``
            and ``state_dict`` entries.
        output_csv_path: where the submission CSV is written.
        images_root_path: directory containing the extracted test images.
    """
    # map_location lets checkpoints saved on a GPU machine load on
    # CPU-only inference hosts instead of raising.
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]
    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # inference mode: disable any training-only behavior

    transform = TransformDino()
    dataloader = DataLoader(
        ImageDataset(test_metadata, images_root_path), batch_size=32, num_workers=4
    )
    rows = []
    # no_grad: the classifier forward pass needs no autograd graph,
    # saving memory across the whole test set.
    with torch.no_grad():
        for batch in dataloader:
            batch = transform.forward(batch)
            observation_ids = batch["observation_id"]
            logits = model(batch["features"])
            class_ids = torch.argmax(logits, dim=1)
            for observation_id, class_id in zip(observation_ids, class_ids):
                row = {"observation_id": int(observation_id), "class_id": int(class_id)}
                rows.append(row)
    submission_df = pd.DataFrame(rows)
    submission_df.to_csv(output_csv_path, index=False)


def parse_args():
    """Parse command-line options: checkpoint path and metadata CSV path."""
    cli = ArgumentParser()
    # Table-driven registration keeps flags and their defaults side by side.
    for flag, default in (
        ("--model-path", "./last.ckpt"),
        ("--metadata-file-path", "./SnakeCLEF2024-TestMetadata.csv"),
    ):
        cli.add_argument(flag, type=str, default=default)
    return cli.parse_args()


if __name__ == "__main__":
    cli_options = parse_args()
    # Unpack the private test images before running inference.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as archive:
        archive.extractall("/tmp/data")
    make_submission(
        test_metadata=cli_options.metadata_file_path,
        model_path=cli_options.model_path,
    )