import pandas as pd
import numpy as np
import os
import sys
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch
from multiprocessing import Pool

from mmpretrain.apis import ImageClassificationInferencer, FeatureExtractor

import mmpretrain.utils.progress as progress

# Silence mmpretrain's internal progress bars; tqdm is used for the outer loops instead.
progress.disable_progress_bar = True

# Allow imports from the parent directory of this script's folder.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))


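# Images are decoded in Pool worker processes; PIL yields RGB, and the [:, :, ::-1]
# slice flips to BGR, matching the OpenCV-style input the mmpretrain pipeline is assumed to expect.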
def load_image(path: str, images_root_path="/tmp/data/private_testset"):
    return np.array(Image.open(os.path.join(images_root_path, path)))[:, :, ::-1]


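# Safety-oriented re-ranking: if the best-scoring poisonous class comes close enough to the
# overall top score, predict that poisonous class instead. The factor of 13 is a tuned
# constant, presumably reflecting the much higher penalty FungiCLEF assigns to labelling a
# poisonous species as edible than the reverse.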
def rerank_poison(poison_status_list: pd.DataFrame, pred_scores: np.ndarray) -> tuple[int, float]:
    class_id = np.argmax(pred_scores)
    class_score = np.max(pred_scores)

    # Assumes poison_status_list rows are sorted by class_id so that row i lines up with pred_scores[i].
    poisonous = poison_status_list.copy()
    poisonous['score'] = pred_scores
    poisonous.sort_values(by=['score'], ascending=False, inplace=True)
    first_poisonous = poisonous[poisonous['poisonous'] == 1].iloc[0]

    if 13 * first_poisonous['score'] > class_score:
        class_id = first_poisonous['class_id']
        class_score = first_poisonous['score']

    return class_id, class_score


def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Run inference with the given model and write the submission CSV.

    Per-image predictions are made first; features are then fused per observation,
    re-ranked toward poisonous classes, and thresholded on entropy for the unknown class (-1).
    """
    feature_extractor = FeatureExtractor(model=model_name, pretrained=model_path, device="cuda:0")

    predictions = []
    prediction_scores = []
    prediction_scores_dict = {}
    prediction_feats_dict = {}
    obs_imgs_dict = {}

    BATCH_SIZE = 4
    p = Pool(BATCH_SIZE)

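    # Per-image inference in mini-batches: images are decoded in Pool worker processes,
    # features are extracted on the GPU, and each image's features are cached per
    # observation_id for the fusion step below.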
    for i in tqdm(range(int(np.ceil(test_metadata.shape[0] / BATCH_SIZE)))):
        img_paths_batch = test_metadata['image_path'][i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        batch_imgs = p.map(load_image, img_paths_batch)

        # FeatureExtractor returns a tuple of stage features per image; keep the first
        # entry for each image and stack them into a single batch tensor for the species head.
        feats = feature_extractor(batch_imgs, batch_size=BATCH_SIZE)
        feats = (torch.stack([x[0] for x in feats], dim=0),)
        results = feature_extractor.model.head.task_heads['species'].predict(feats, img_paths=img_paths_batch)

        for res, f, obs_id, img_path in zip(results, feats[0], test_metadata['observation_id'][i * BATCH_SIZE:(i + 1) * BATCH_SIZE], img_paths_batch):
            pred_scores = res.pred_score.detach().cpu().numpy()

            predictions.append(np.argmax(pred_scores))
            prediction_scores.append(pred_scores)
            prediction_scores_dict.setdefault(obs_id, []).append(pred_scores)
            prediction_feats_dict.setdefault(obs_id, []).append(f)
            obs_imgs_dict[obs_id] = img_path

    p.close()
    print('finished inference')

    test_metadata["class_id"] = predictions
    test_metadata["max_score"] = prediction_scores

    # Sort by class_id so that row i of the table lines up with score index i in rerank_poison.
    poison_status_list = pd.read_csv('poison_status_list.csv')
    poison_status_list = poison_status_list.sort_values(by=['class_id'])

    poison_classes = set(poison_status_list[poison_status_list['poisonous'] == 1]['class_id'])

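    # Fuse each observation's per-image features by averaging, predict once on the fused
    # feature, re-rank toward poisonous classes, and fall back to the unknown class (-1)
    # when the score distribution is too flat (high entropy).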
    for obs_id, pred_feats in tqdm(prediction_feats_dict.items()):
        fusion_feats = torch.mean(torch.stack(pred_feats, dim=0), dim=0, keepdim=True)
        results = feature_extractor.model.head.task_heads['species'].predict((fusion_feats,), img_paths=[obs_imgs_dict[obs_id]])
        fusion_scores = results[0].pred_score.detach().cpu().numpy()

        class_id, class_score = rerank_poison(poison_status_list, fusion_scores)

        # Abstain on high-entropy predictions; non-poisonous predictions are rejected at a
        # lower entropy threshold than poisonous ones.
        entropy = -np.sum(fusion_scores * np.log(fusion_scores))
        if entropy > 7 or (class_id not in poison_classes and entropy > 2.5):
            class_id = -1

        test_metadata.loc[test_metadata["observation_id"] == obs_id, "class_id"] = class_id
        test_metadata.loc[test_metadata["observation_id"] == obs_id, "max_score"] = class_score

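    # test_metadata has one row per image; keep the first row per observation_id, since
    # every row of an observation now carries the same fused prediction.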
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


if __name__ == "__main__":
    import zipfile

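    # Extract the private test set so images are available under /tmp/data/private_testset,
    # the default images_root_path read by load_image.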
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    # MODEL_NAME is the mmpretrain config file; MODEL_PATH is the trained checkpoint.
    MODEL_PATH = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4_epoch_2_20240524-a429ecac.pth"
    MODEL_NAME = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py"

    metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        model_name=MODEL_NAME,
    )