Anthony Miyaguchi committed
Commit 2076935 · 1 Parent(s): b940c12

Add initial submission files

Files changed (5)
  1. __init__.py +0 -0
  2. data.py +69 -0
  3. model.py +81 -0
  4. submission.py +41 -0
  5. test_evaluate.py +82 -0
__init__.py ADDED
File without changes
data.py ADDED
@@ -0,0 +1,69 @@
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModel


class TransformDino(v2.Transform):
    def __init__(self, model_name="facebook/dinov2-base"):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, batch):
        model_inputs = self.processor(images=batch["features"], return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        last_hidden_states = outputs.last_hidden_state
        # extract the cls token
        batch["features"] = last_hidden_states[:, 0]
        return batch


class ImageDataset(Dataset):
    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.image_path
        img = Image.open(image_path).convert("RGB")
        img = v2.ToTensor()(img)
        return {"features": img, "observation_id": row.observation_id}


class InferenceDataModel(pl.LightningDataModule):
    def __init__(
        self,
        metadata_path,
        images_root_path,
        batch_size=32,
    ):
        super().__init__()
        self.metadata_path = metadata_path
        self.images_root_path = images_root_path
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.dataloader = DataLoader(
            ImageDataset(self.metadata_path, self.images_root_path),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
        )

    def predict_dataloader(self):
        transform = v2.Compose([TransformDino("facebook/dinov2-base")])
        for batch in self.dataloader:
            batch = transform(batch)
            yield batch
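
As a quick sanity check, here is a minimal sketch (not part of the commit) of driving the datamodule directly. The CSV and image paths are placeholders, the import should be adjusted to the package layout, and the first run downloads the DINOv2 weights from the Hugging Face Hub:

# Hypothetical smoke test for InferenceDataModel; all paths are placeholders.
from data import InferenceDataModel  # adjust the import to your package layout

dm = InferenceDataModel(
    metadata_path="metadata.csv",   # assumed CSV with image_path and observation_id columns
    images_root_path="images/",     # assumed directory of same-size RGB images
    batch_size=8,                   # default collate stacks images, so sizes must match
)
dm.setup()
for batch in dm.predict_dataloader():
    # DINOv2-base CLS embeddings are 768-dimensional
    print(batch["features"].shape)  # torch.Size([8, 768])
    break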
model.py ADDED
@@ -0,0 +1,81 @@
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)


class LinearClassifier(pl.LightningModule):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.save_hyperparameters()  # Saves hyperparams in the checkpoints
        self.model = nn.Linear(num_features, num_classes)
        self.learning_rate = 0.002
        self.accuracy = MulticlassAccuracy(num_classes=num_classes, average="weighted")
        self.f1_score = MulticlassF1Score(num_classes=num_classes, average="weighted")
        self.precision = MulticlassPrecision(
            num_classes=num_classes, average="weighted"
        )
        self.recall = MulticlassRecall(num_classes=num_classes, average="weighted")

    def forward(self, x):
        return torch.log_softmax(self.model(x), dim=1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def _run_step(self, batch, batch_idx, step_name):
        x, y = batch["features"], batch["label"]
        logits = self(x)
        loss = torch.nn.functional.nll_loss(logits, y)
        self.log(f"{step_name}_loss", loss, prog_bar=True)
        self.log(
            f"{step_name}_accuracy",
            self.accuracy(logits, y),
            on_step=False,
            on_epoch=True,
        )
        if step_name != "train":
            self.log(
                f"{step_name}_f1",
                self.f1_score(logits, y),
                on_step=False,
                on_epoch=True,
            )
            self.log(
                f"{step_name}_precision",
                self.precision(logits, y),
                on_step=False,
                on_epoch=True,
            )
            self.log(
                f"{step_name}_recall",
                self.recall(logits, y),
                on_step=False,
                on_epoch=True,
            )
        return loss

    def training_step(self, batch, batch_idx):
        return self._run_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._run_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._run_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        logits = self(batch["features"])
        return {
            "logits": logits,
            "class_id": torch.argmax(logits, dim=1),
            "observation_id": batch["observation_id"],
        }
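
For orientation, a minimal sketch (again not part of the commit) of the shapes the classifier produces, assuming 768-dimensional DINOv2-base features and 10 classes:

import torch
from model import LinearClassifier  # adjust the import to your package layout

model = LinearClassifier(num_features=768, num_classes=10)
batch = {"features": torch.randn(4, 768), "observation_id": torch.arange(4)}
out = model.predict_step(batch, batch_idx=0)
print(out["logits"].shape)  # torch.Size([4, 10]), log-probabilities per class
print(out["class_id"])      # argmax class index per sample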
submission.py ADDED
@@ -0,0 +1,41 @@
import zipfile

import pandas as pd
import torch
from pytorch_lightning import Trainer

from .data import InferenceDataModel
from .model import LinearClassifier


def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    model = LinearClassifier.load_from_checkpoint(model_path)
    dm = InferenceDataModel(
        metadata_path=test_metadata, images_root_path=images_root_path
    )
    trainer = Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
    )
    predictions = trainer.predict(model, datamodule=dm)
    rows = []
    for batch in predictions:
        for observation_id, class_id in zip(batch["observation_id"], batch["class_id"]):
            row = {"observation_id": int(observation_id), "class_id": int(class_id)}
            rows.append(row)
    submission_df = pd.DataFrame(rows)
    submission_df.to_csv(output_csv_path, index=False)


if __name__ == "__main__":
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
        zip_ref.extractall("/tmp/data")

    MODEL_PATH = "last.ckpt"
    metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
    # InferenceDataModel reads the CSV itself, so pass the path, not a DataFrame
    make_submission(test_metadata=metadata_file_path, model_path=MODEL_PATH)
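
Note that submission.py uses relative imports, so it cannot run as a plain script; it has to be invoked as a module from the parent directory. A hedged sketch of that invocation, where "snakeclef" is a placeholder package name not fixed by this commit:

# python -m snakeclef.submission   (package name is a placeholder)
from snakeclef.submission import make_submission

make_submission(
    test_metadata="./SnakeCLEF2024-TestMetadata.csv",  # a path: the datamodule reads the CSV itself
    model_path="last.ckpt",
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
)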
test_evaluate.py ADDED
@@ -0,0 +1,82 @@
import numpy as np
import pandas as pd
import PIL
import pytest
import torch
from pytorch_lightning import Trainer

from .data import ImageDataset, InferenceDataModel
from .model import LinearClassifier
from .submission import make_submission


class TestingInferenceDataModel(InferenceDataModel):
    def train_dataloader(self):
        for batch in self.predict_dataloader():
            # add a label to the batch with classes from 0 to 9
            batch["label"] = torch.randint(0, 10, (batch["features"].shape[0],))
            yield batch


@pytest.fixture
def images_root(tmp_path):
    images_root = tmp_path / "images"
    images_root.mkdir()
    for i in range(10):
        img = PIL.Image.fromarray(
            np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
        )
        img.save(images_root / f"{i}.jpg")
    return images_root


@pytest.fixture
def metadata(tmp_path, images_root):
    res = []
    for i, img in enumerate(images_root.glob("*.jpg")):
        res.append({"image_path": img.name, "observation_id": i})
    df = pd.DataFrame(res)
    df.to_csv(tmp_path / "metadata.csv", index=False)
    return tmp_path / "metadata.csv"


@pytest.fixture
def model_checkpoint(tmp_path, metadata, images_root):
    model_checkpoint = tmp_path / "model.ckpt"
    model = LinearClassifier(768, 10)
    trainer = Trainer(max_epochs=1, fast_dev_run=True)
    dm = TestingInferenceDataModel(metadata, images_root)
    trainer.fit(model, dm)
    trainer.save_checkpoint(model_checkpoint)
    return model_checkpoint


def test_image_dataset(images_root, metadata):
    dataset = ImageDataset(metadata, images_root)
    assert len(dataset) == 10
    for i in range(10):
        assert dataset[i]["features"].shape == torch.Size([3, 100, 100])


def test_inference_datamodel(images_root, metadata):
    batch_size = 5
    model = InferenceDataModel(metadata, images_root, batch_size=batch_size)
    model.setup()
    assert len(model.dataloader) == 2
    for batch in model.predict_dataloader():
        assert set(batch.keys()) == {"features", "observation_id"}
        assert batch["features"].shape == torch.Size([batch_size, 768])


def test_model_checkpoint(model_checkpoint):
    model = LinearClassifier.load_from_checkpoint(model_checkpoint)
    assert model


def test_make_submission(model_checkpoint, metadata, images_root, tmp_path):
    output_csv_path = tmp_path / "submission.csv"
    make_submission(metadata, model_checkpoint, output_csv_path, images_root)
    submission_df = pd.read_csv(output_csv_path)
    assert len(submission_df) == 10
    assert set(submission_df.columns) == {"observation_id", "class_id"}
    assert submission_df["class_id"].isin(range(10)).all()
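
The same package-context caveat applies here: because the tests use relative imports, they need the __init__.py added in this commit and should be collected from the parent directory, e.g. pytest snakeclef/test_evaluate.py with the placeholder package name above. The model_checkpoint fixture trains for a single fast_dev_run batch on random labels, so it exercises only the checkpoint round-trip, not model quality.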