Commit fa626df
Parent(s): ee66a83

Added challenge inference script

Files changed:
- Dockerfile +22 -0
- docker-compose.yml +24 -0
- evaluate.py +211 -0
- predict.py +70 -0
- requirements.txt +3 -0
- start.sh +11 -0
- tools/test_generate_result_pre-consensus.py +35 -68
Dockerfile
ADDED
@@ -0,0 +1,22 @@
+FROM nvcr.io/nvidia/pytorch:22.12-py3
+# FROM python:3.8.10-slim
+# alternatively use python image as a base image, if PyTorch with GPU drivers is not needed
+
+#COPY requirements.txt /script/requirements.txt
+#COPY predict.py /script/predict.py
+#COPY evaluate.py /script/evaluate.py
+#COPY start.sh /script/start.sh
+#COPY ./FungiCLEF2023-ViT_base_patch16_224-100E.pth /script/FungiCLEF2023-ViT_base_patch16_224-100E.pth
+
+COPY . /script/
+
+# install python dependencies
+ENV SCRIPT_DIR='/script'
+WORKDIR $SCRIPT_DIR
+RUN pip install --no-cache-dir --upgrade pip build && \
+    pip install --no-cache-dir --compile -r requirements.txt && \
+    mim install "mmpretrain==1.0.0rc7" && \
+    rm -rf /var/lib/apt/lists/* /var/cache/apt/* /tmp/* /var/tmp/*
+
+# run script
+CMD bash start.sh
docker-compose.yml
ADDED
@@ -0,0 +1,24 @@
+version: "3.8"
+
+services:
+  fungiclef:
+    image: fungiclef-example:latest
+    build: .
+    container_name: fungiclef-example
+    volumes:
+      - /media/Data-10T-1/Data/DF21_300:/Data
+    # settings related to nvcr.io/nvidia/pytorch:22.12-py3 docker image
+    ipc: "host"
+    ulimits:
+      memlock: -1
+      stack: 67108864
+    deploy:
+      resources:
+        limits:
+          cpus: "4.0"
+          memory: "12g"
+        reservations:
+          devices:
+            - capabilities: [ "gpu" ]
+              driver: nvidia
+              count: all
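
The `deploy.resources.reservations.devices` block is what exposes the NVIDIA GPUs to the container; `ipc: "host"` and the ulimits mirror the settings recommended for the nvcr.io PyTorch images. For anyone scripting the same run without compose, a rough equivalent via the `docker` Python SDK would look like the sketch below — the SDK usage is an assumption on my part, the repo only ships the compose file:

import docker
from docker.types import DeviceRequest, Ulimit

client = docker.from_env()
container = client.containers.run(
    "fungiclef-example:latest",
    name="fungiclef-example",
    volumes={"/media/Data-10T-1/Data/DF21_300": {"bind": "/Data", "mode": "rw"}},
    ipc_mode="host",                      # ipc: "host"
    ulimits=[Ulimit(name="memlock", soft=-1, hard=-1),
             Ulimit(name="stack", soft=67108864, hard=67108864)],
    nano_cpus=4_000_000_000,              # cpus: "4.0"
    mem_limit="12g",
    device_requests=[DeviceRequest(driver="nvidia", count=-1,   # count: all
                                   capabilities=[["gpu"]])],
    detach=True,
)
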
evaluate.py
ADDED
@@ -0,0 +1,211 @@
+from typing import List
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import f1_score
+
+COLUMNS = ["observationID", "class_id"]
+poisonous_lvl = pd.read_csv(
+    "http://ptak.felk.cvut.cz/plants//DanishFungiDataset/poison_status_list.csv"
+)
+POISONOUS_SPECIES = poisonous_lvl[poisonous_lvl["poisonous"] == 1].class_id.unique()
+
+
+def classification_error_with_unknown(
+    merged_df, cost_unkwnown_misclassified=10, cost_misclassified_as_unknown=0.1
+):
+    num_misclassified_unknown = sum((merged_df.class_id_gt == -1) & (merged_df.class_id_pred != -1))
+    num_misclassified_as_unknown = sum(
+        (merged_df.class_id_gt != -1) & (merged_df.class_id_pred == -1)
+    )
+    num_misclassified_other = sum(
+        (merged_df.class_id_gt != merged_df.class_id_pred)
+        & (merged_df.class_id_pred != -1)
+        & (merged_df.class_id_gt != -1)
+    )
+    return (
+        num_misclassified_other
+        + num_misclassified_unknown * cost_unkwnown_misclassified
+        + num_misclassified_as_unknown * cost_misclassified_as_unknown
+    ) / len(merged_df)
+
+
+def classification_error(merged_df):
+    return classification_error_with_unknown(
+        merged_df, cost_misclassified_as_unknown=1, cost_unkwnown_misclassified=1
+    )
+
+
+def num_psc_decisions(merged_df):
+    # Number of observations that were misclassified as edible, when in fact they are poisonous
+    num_psc = sum(
+        merged_df.class_id_gt.isin(POISONOUS_SPECIES)
+        & ~merged_df.class_id_pred.isin(POISONOUS_SPECIES)
+    )
+    return num_psc
+
+
+def num_esc_decisions(merged_df):
+    # Number of observations that were misclassified as poisonous, when in fact they are edible
+    num_esc = sum(
+        ~merged_df.class_id_gt.isin(POISONOUS_SPECIES)
+        & merged_df.class_id_pred.isin(POISONOUS_SPECIES)
+    )
+    return num_esc
+
+
+def psc_esc_cost_score(merged_df, cost_psc=100, cost_esc=1):
+    return (
+        cost_psc * num_psc_decisions(merged_df) + cost_esc * num_esc_decisions(merged_df)
+    ) / len(merged_df)
+
+
+def evaluate_csv(test_annotation_file: str, user_submission_file: str) -> List[dict]:
+    # load gt annotations
+    gt_df = pd.read_csv(test_annotation_file, sep=",")
+    for col in COLUMNS:
+        assert col in gt_df, f"Test annotation file is missing column '{col}'."
+    # keep only observation-based predictions
+    gt_df = gt_df.drop_duplicates("observationID")
+
+    # load user predictions
+    try:
+        is_tsv = user_submission_file.endswith(".tsv")
+        user_pred_df = pd.read_csv(user_submission_file, sep="\t" if is_tsv else ",")
+    except Exception:
+        print("Could not read file submitted by the user.")
+        raise ValueError("Could not read file submitted by the user.")
+
+    # validate user predictions
+    for col in COLUMNS:
+        if col not in user_pred_df:
+            print(f"File submitted by the user is missing column '{col}'.")
+            raise ValueError(f"File submitted by the user is missing column '{col}'.")
+    if len(gt_df) != len(user_pred_df):
+        print(f"File submitted by the user should have {len(gt_df)} records.")
+        raise ValueError(f"File submitted by the user should have {len(gt_df)} records.")
+    missing_obs = gt_df.loc[
+        ~gt_df["observationID"].isin(user_pred_df["observationID"]),
+        "observationID",
+    ]
+    if len(missing_obs) > 0:
+        if len(missing_obs) > 3:
+            missing_obs_str = ", ".join(missing_obs.iloc[:3].astype(str)) + ", ..."
+        else:
+            missing_obs_str = ", ".join(missing_obs.astype(str))
+        print(f"File submitted by the user is missing observations: {missing_obs_str}")
+        raise ValueError(f"File submitted by the user is missing observations: {missing_obs_str}")
+
+    # merge dataframes
+    merged_df = pd.merge(
+        gt_df,
+        user_pred_df,
+        how="outer",
+        on="observationID",
+        validate="one_to_one",
+        suffixes=("_gt", "_pred"),
+    )
+
+    # evaluate accuracy_score and f1_score
+    cls_error = classification_error(merged_df)
+    cls_error_with_unknown = classification_error_with_unknown(merged_df)
+    psc_esc_cost = psc_esc_cost_score(merged_df)
+
+    result = [
+        {
+            "test_split": {
+                "F1 Score": np.round(
+                    f1_score(merged_df["class_id_gt"], merged_df["class_id_pred"], average="macro")
+                    * 100,
+                    2,
+                ),
+                "Track 1: Classification Error": np.round(cls_error, 4),
+                "Track 2: Cost for Poisonousness Confusion": np.round(psc_esc_cost, 4),
+                "Track 3: User-Focused Loss": np.round(cls_error + psc_esc_cost, 4),
+                "Track 4: Classification Error with Special Cost for Unknown": np.round(
+                    cls_error_with_unknown, 4
+                ),
+            }
+        }
+    ]
+
+    print(f"Evaluated scores: {result[0]['test_split']}")
+
+    return result
+
+
+def evaluate(test_annotation_file, user_submission_file, phase_codename, **kwargs):
+    """
+    Evaluates the submission for a particular challenge phase and returns score
+    Arguments:
+
+        `test_annotations_file`: Path to test_annotation_file on the server
+        `user_submission_file`: Path to file submitted by the user
+        `phase_codename`: Phase to which submission is made
+
+        `**kwargs`: keyword arguments that contains additional submission
+        metadata that challenge hosts can use to send slack notification.
+        You can access the submission metadata
+        with kwargs['submission_metadata']
+
+        Example: A sample submission metadata can be accessed like this:
+        >>> print(kwargs['submission_metadata'])
+        {
+            'status': u'running',
+            'when_made_public': None,
+            'participant_team': 5,
+            'input_file': 'https://abc.xyz/path/to/submission/file.json',
+            'execution_time': u'123',
+            'publication_url': u'ABC',
+            'challenge_phase': 1,
+            'created_by': u'ABC',
+            'stdout_file': 'https://abc.xyz/path/to/stdout/file.json',
+            'method_name': u'Test',
+            'stderr_file': 'https://abc.xyz/path/to/stderr/file.json',
+            'participant_team_name': u'Test Team',
+            'project_url': u'http://foo.bar',
+            'method_description': u'ABC',
+            'is_public': False,
+            'submission_result_file': 'https://abc.xyz/path/result/file.json',
+            'id': 123,
+            'submitted_at': u'2017-03-20T19:22:03.880652Z'
+        }
+    """
+    print("Starting Evaluation.....")
+    out = {}
+    if phase_codename == "prediction-based":
+        print("Evaluating for Prediction-based Phase")
+        out["result"] = evaluate_csv(test_annotation_file, user_submission_file)
+
+        # To display the results in the result file
+        out["submission_result"] = out["result"][0]["test_split"]
+        print("Completed evaluation")
+    return out
+
+
+if __name__ == "__main__":
+    import argparse
+    import json
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--test-annotation-file",
+        help="Path to test_annotation_file on the server.",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--user-submission-file",
+        help="Path to a file created by predict script.",
+        type=str,
+        required=True,
+    )
+    args = parser.parse_args()
+
+    result = evaluate(
+        test_annotation_file=args.test_annotation_file,
+        user_submission_file=args.user_submission_file,
+        phase_codename="prediction-based",
+    )
+    with open("scores.json", "w") as f:
+        json.dump(result, f)
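
For intuition about the metric weighting, here is a small sketch that builds a toy merged frame in the shape `evaluate_csv` produces (the toy IDs and values are made up; note that importing `evaluate` fetches the poison-status CSV over HTTP at import time):

import pandas as pd

from evaluate import classification_error, classification_error_with_unknown

# Toy merged frame; class_id -1 denotes "unknown" in both columns.
merged_df = pd.DataFrame({
    "observationID": [1, 2, 3, 4],
    "class_id_gt":   [5, -1,  7,  9],
    "class_id_pred": [5,  3, -1,  8],
})

# Track 1 weights every mistake equally: rows 2-4 are wrong -> 3/4.
print(classification_error(merged_df))               # 0.75
# Track 4 weights them differently: predicting a class for an unknown
# observation costs 10, abstaining (-1) on a known class costs 0.1,
# an ordinary confusion costs 1 -> (10 + 0.1 + 1) / 4 = 2.775.
print(classification_error_with_unknown(merged_df))  # 2.775
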
predict.py
ADDED
@@ -0,0 +1,70 @@
+import os
+import os.path as osp
+import subprocess
+import sys
+from pathlib import Path
+
+
+# custom script arguments
+CONFIG_PATH = 'work_dirs/swin_base_b32x4-fp16_fungi+val_res_384_cb_epochs_6/swin_base_b32x4-fp16_fungi+val-test_res_384_cb_epochs_6.py'
+CHECKPOINT_PATH = "work_dirs/swin_base_b32x4-fp16_fungi+val_res_384_cb_epochs_6/epoch_6.pth"
+SCORE_THRESHOLD = 0.2
+
+
+def run_inference(input_csv, output_csv, data_root_path):
+    """Load model and dataloader and run inference."""
+
+    if not data_root_path.endswith('/'):
+        data_root_path += '/'
+    data_cfg_opts = [
+        f'test_dataloader.dataset.data_root=',
+        f'test_dataloader.dataset.ann_file={input_csv}',
+        f'test_dataloader.dataset.data_prefix={data_root_path}']
+
+    inference = subprocess.Popen([
+        'python', '-m',
+        'tools.test_generate_result_pre-consensus',
+        CONFIG_PATH, CHECKPOINT_PATH,
+        output_csv,
+        '--threshold', str(SCORE_THRESHOLD),
+        '--no-scores',
+        '--cfg-options'] + data_cfg_opts)
+    return_code = inference.wait()
+    if return_code != 0:
+        print(f'Inference crashed with exit code {return_code}')
+        sys.exit(return_code)
+    print(f'Written {output_csv}')
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input-file",
+        help="Path to a file with observation ids and image paths.",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--data-root-path",
+        help="Path to a directory where images are stored.",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--output-file",
+        help="Path to a file where predict script will store predictions.",
+        type=str,
+        required=True,
+    )
+    args = parser.parse_args()
+
+    output_csv = os.path.basename(args.output_file)
+    if not output_csv.endswith(".csv"):
+        output_csv = output_csv + ".csv"
+    run_inference(
+        input_csv=args.input_file,
+        output_csv=output_csv,
+        data_root_path=args.data_root_path,
+    )
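
predict.py never edits the config file itself; it forwards the three dotted `test_dataloader.dataset.*` overrides to the test script via `--cfg-options`, which (as the change to tools/test_generate_result_pre-consensus.py below shows) parses them with mmengine's `DictAction` and applies them with `Config.merge_from_dict`. A minimal sketch of that mechanism, with made-up values:

from mmengine.config import Config

# Stand-in for the config loaded by the test script (values are hypothetical).
cfg = Config(dict(test_dataloader=dict(dataset=dict(
    data_root='old/', ann_file='old.csv', data_prefix='old/'))))

# mmengine resolves dotted keys into nested config fields, exactly what
# --cfg-options + DictAction feed into cfg.merge_from_dict in main().
cfg.merge_from_dict({
    'test_dataloader.dataset.data_root': '',           # cleared by predict.py
    'test_dataloader.dataset.ann_file': 'meta.csv',    # hypothetical values
    'test_dataloader.dataset.data_prefix': '/Data/',
})
print(cfg.test_dataloader.dataset.ann_file)  # meta.csv
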
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+openmim
+tensorboard
+future
start.sh
ADDED
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+METADATA_CSV='http://ptak.felk.cvut.cz/plants/DanishFungiDataset/FungiCLEF2023_val_metadata_PRODUCTION.csv'
+DATA_ROOT_PATH='/Data'
+python ./predict.py \
+    --input-file $METADATA_CSV \
+    --data-root-path $DATA_ROOT_PATH \
+    --output-file user_submission.csv && \
+python ./evaluate.py \
+    --test-annotation-file $METADATA_CSV \
+    --user-submission-file user_submission.csv
tools/test_generate_result_pre-consensus.py
CHANGED
@@ -17,21 +17,11 @@ import torch
 
 from torch.nn import DataParallel
 
-
-from mmengine.model.wrappers import MMDistributedDataParallel
-#from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-#from mmcv.runner import (load_checkpoint,
-#                         wrap_fp16_model)
+from mmengine.config import DictAction
 from mmengine.runner import load_checkpoint
 from mmengine.registry import DefaultScope
-
-
-from mmpretrain.datasets import build_dataset
 from mmengine.runner import Runner
-#from mmcls.datasets import build_dataloader, build_dataset
 from mmpretrain.models import build_classifier
-#from mmcls.models import build_classifier
-#from mmcls.utils import get_root_logger, setup_multi_processes
 
 
 def parse_args():
@@ -39,21 +29,23 @@ def parse_args():
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument('out', help='output result file')
+    parser.add_argument('--threshold', default=None, type=float, help='open-set threshold')
+    parser.add_argument('--no-scores', action='store_true', help='don\'t write score .csv file')
     parser.add_argument(
         '--gpu-collect',
         action='store_true',
         help='whether to use gpu to collect results')
     parser.add_argument('--tmpdir', help='tmp dir for writing some results')
-
-
-
-
-
-
-
-
-
-
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     parser.add_argument(
         '--device', default=None, help='device used for testing. (Deprecated)')
     parser.add_argument(
@@ -103,8 +95,6 @@ def single_gpu_test(model,
     observation_ids = []
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            #result = model(return_loss=False, **data)
-            #imgs = data['inputs'].cuda()
             #data = model.module.data_preprocessor(data, training=False)
             imgs = data['inputs'].cuda()
             result = model.module.extract_feat(imgs)
@@ -112,12 +102,10 @@ def single_gpu_test(model,
         filenames = [x.img_path for x in data['data_samples']]
         obs_ids = [osp.basename(x).split('.')[0].split('-')[1] for x in filenames]
         result = list(zip(result[0], obs_ids))
-        #print(result)
 
         batch_size = len(result)
         results.extend(result)
 
-        #batch_size = data['img'].size(0)
         prog_bar.update(batch_size)
     return results
 
@@ -167,8 +155,8 @@ def main():
     default_scope = DefaultScope.get_instance('test', scope_name='mmpretrain')
 
     cfg = mmengine.Config.fromfile(args.config) #mmcv.Config.fromfile(args.config)
-
-
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
 
     # set multi-process settings
     setup_multi_processes(cfg)
@@ -187,42 +175,12 @@ def main():
     else:
         cfg.gpu_ids = [args.gpu_id]
 
-    # dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
-
-    # # build the dataloader
-    # # The default loader config
-    # loader_cfg = dict(
-    #     # cfg.gpus will be ignored if distributed
-    #     num_gpus=len(cfg.gpu_ids),
-    #     dist=False,
-    #     round_up=True,
-    # )
-    # # The overall dataloader settings
-    # loader_cfg.update({
-    #     k: v
-    #     for k, v in cfg.data.items() if k not in [
-    #         'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-    #         'test_dataloader'
-    #     ]
-    # })
-    # test_loader_cfg = {
-    #     **loader_cfg,
-    #     'shuffle': False,  # Not shuffle by default
-    #     'sampler_cfg': None,  # Not use sampler by default
-    #     **cfg.data.get('test_dataloader', {}),
-    # }
-    # the extra round_up data will be removed during gpu/cpu collect
     data_loader = Runner.build_dataloader(cfg.test_dataloader)
 
     # build the model and load checkpoint
     model = build_classifier(cfg.model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
 
-    # if 'CLASSES' in checkpoint.get('meta', {}):
-    #     CLASSES = checkpoint['meta']['CLASSES']
-    # if CLASSES is None:
-    #     CLASSES = dataset.CLASSES
-
     if args.device == 'cpu':
         model = model.cpu()
     else:
@@ -231,30 +189,39 @@ def main():
     assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
         'To test with CPU, please confirm your mmcv version ' \
         'is not lower than v1.4.4'
-
+
     outputs = single_gpu_test(model, data_loader)
 
     results = defaultdict(list)
     for result, obs_id in outputs:
         results[obs_id].append(result)
+
+    if not args.no_scores:
+        with open(args.out + '.scores.csv', 'w') as f2:
+            for obs_id, result in results.items():
+                avg_feats = torch.mean(torch.stack(result, dim=0), dim=0, keepdim=True)
+                scores = model.module.head(avg_feats)
+                f2.write(f'{obs_id}')
+                for s in scores:
+                    f2.write(f',{s}')
+                f2.write('\n')
 
     dropped = 0
     total = 0
-    with open(args.out, 'w') as f
-        f.write('
+    with open(args.out, 'w') as f:
+        f.write('observationID,class_id\n')
         for obs_id, result in results.items():
             avg_feats = torch.mean(torch.stack(result, dim=0), dim=0, keepdim=True)
            scores = model.module.head(avg_feats)
-
-
-
-
+            scores = scores.detach().cpu().numpy()
+            class_id = np.argmax(scores)
+            if args.threshold:
+                max_score = float(torch.max(torch.softmax(torch.from_numpy(scores), dim=0)))
+                if max_score < args.threshold:
+                    class_id = -1
+                    dropped += 1
            total += 1
            f.write(f'{obs_id},{float(class_id):.1f}\n')
-            f2.write(f'{obs_id}')
-            for s in scores:
-                f2.write(f',{s}')
-            f2.write('\n')
 
     print(f'dropped {dropped} out of {total}')
 
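
The open-set handling added to the test script is a plain maximum-softmax-probability rejection: the per-image features of each observation are averaged, scored once by the classification head, and the observation is emitted as class -1 ("unknown") when the top softmax score falls below `--threshold` (0.2 in predict.py). A self-contained sketch of just that decision rule, with toy logits:

import torch

def predict_with_rejection(logits: torch.Tensor, threshold: float = 0.2) -> int:
    """Max-softmax open-set rule: return -1 ("unknown") when the top
    softmax probability falls below the threshold."""
    probs = torch.softmax(logits, dim=0)
    max_score, class_id = torch.max(probs, dim=0)
    return -1 if float(max_score) < threshold else int(class_id)

# A confident prediction keeps its class; a flat score vector is rejected.
print(predict_with_rejection(torch.tensor([4.0, 0.5, 0.2])))  # 0
print(predict_with_rejection(torch.tensor([0.1] * 100)))      # -1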