dsgt-snakeclef / submission.py
Anthony Miyaguchi
Add initial submission files
2076935
raw
history blame
1.33 kB
import zipfile
import pandas as pd
import torch
from pytorch_lightning import Trainer
from .data import InferenceDataModel
from .model import LinearClassifier
def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    """Run inference with a saved checkpoint and write the submission CSV.

    Loads a ``LinearClassifier`` from ``model_path``, runs
    ``Trainer.predict`` over an ``InferenceDataModel`` and writes one row per
    prediction with the columns ``observation_id`` and ``class_id``.

    Args:
        test_metadata: Forwarded unchanged to ``InferenceDataModel`` as its
            ``metadata_path`` argument. NOTE(review): the script entry point
            in this file passes an already-loaded DataFrame rather than a
            path -- confirm which one the data module expects.
        model_path: Checkpoint handed to
            ``LinearClassifier.load_from_checkpoint``.
        output_csv_path: Destination of the submission CSV.
        images_root_path: Forwarded to ``InferenceDataModel`` as
            ``images_root_path``.
    """
    model = LinearClassifier.load_from_checkpoint(model_path)
    dm = InferenceDataModel(
        metadata_path=test_metadata, images_root_path=images_root_path
    )
    trainer = Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
    )
    predictions = trainer.predict(model, datamodule=dm)

    # Each prediction batch is indexed by "observation_id" and "class_id";
    # the paired values are flattened into one row per observation.
    rows = []
    for batch in predictions:
        for observation_id, class_id in zip(batch["observation_id"], batch["class_id"]):
            rows.append(
                {"observation_id": int(observation_id), "class_id": int(class_id)}
            )

    # Name the columns explicitly: a DataFrame built from an empty list has
    # no columns at all, which would produce a CSV without its header row.
    submission_df = pd.DataFrame(rows, columns=["observation_id", "class_id"])
    submission_df.to_csv(output_csv_path, index=False)
if __name__ == "__main__":
    # Unpack the zipped private test set into /tmp/data before inference.
    archive = zipfile.ZipFile("/tmp/data/private_testset.zip", "r")
    with archive:
        archive.extractall("/tmp/data")

    # Load the test metadata table and generate the submission from the
    # checkpoint, using make_submission's default output and image paths.
    checkpoint = "last.ckpt"
    metadata = pd.read_csv("./SnakeCLEF2024-TestMetadata.csv")
    make_submission(test_metadata=metadata, model_path=checkpoint)