Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- configs/dinov2/dinov2_upernet_water.py +13 -0
- configs/ktda/dinov2_b_frozen-fam-fmm.py +18 -0
- configs/ktda/dinov2_b_frozen-fam.py +13 -0
- configs/ktda/experiment_a.py +14 -0
- configs/ktda/experiment_aa.py +46 -0
- configs/ktda/experiment_k.py +14 -0
- configs/ktda/experiment_u.py +15 -0
- configs/ktda/experiment_v.py +26 -0
- configs/ktda/ktda_grass.py +19 -0
- configs/pspnet/pspnet_r101_water.py +15 -0
- configs/pspnet/pspnet_r50.py +13 -0
- configs/segformer/segformer_mit-b0_water.py +14 -0
- ktda/datasets/__init__.py +7 -0
- ktda/datasets/grass.py +55 -0
- ktda/datasets/l8_biome.py +29 -0
- ktda/models/__init__.py +4 -0
- ktda/models/__pycache__/__init__.cpython-311.pyc +0 -0
- ktda/models/adapter/__init__.py +4 -0
- ktda/models/adapter/__pycache__/__init__.cpython-311.pyc +0 -0
- ktda/models/adapter/__pycache__/fam.cpython-311.pyc +0 -0
- ktda/models/adapter/__pycache__/fmm.cpython-311.pyc +0 -0
- ktda/models/adapter/fam.py +37 -0
- ktda/models/adapter/fmm.py +109 -0
- ktda/models/segmentors/__pycache__/__init__.cpython-311.pyc +0 -0
- ktda/models/segmentors/__pycache__/distill_encoder_decoder.cpython-311.pyc +0 -0
- ktda/models/segmentors/distill_encoder_decoder.py +382 -0
- requirements/docs.txt +7 -0
- requirements/optional.txt +22 -0
- requirements/runtime.txt +5 -0
- tools/analysis_tools/analyze_logs.py +130 -0
- tools/analysis_tools/benchmark.py +121 -0
- tools/analysis_tools/confusion_matrix.py +197 -0
- tools/analysis_tools/get_flops.py +126 -0
- tools/analysis_tools/visualization_cam.py +127 -0
- tools/dataset_converters/chase_db1.py +89 -0
- tools/dataset_converters/cityscapes.py +56 -0
- tools/dataset_converters/coco_stuff10k.py +308 -0
- tools/dataset_converters/coco_stuff164k.py +265 -0
- tools/dataset_converters/hrf.py +112 -0
- tools/dataset_converters/isaid.py +246 -0
- tools/dataset_converters/levircd.py +99 -0
- tools/dataset_converters/loveda.py +73 -0
- tools/dataset_converters/nyu.py +89 -0
- tools/dataset_converters/pascal_context.py +87 -0
- tools/dataset_converters/potsdam.py +158 -0
- tools/dataset_converters/refuge.py +110 -0
- tools/dataset_converters/stare.py +167 -0
- tools/dataset_converters/synapse.py +155 -0
- tools/dataset_converters/voc_aug.py +92 -0
- tools/dataset_tools/create_dataset.py +185 -0
configs/dinov2/dinov2_upernet_water.py
ADDED
@@ -0,0 +1,13 @@
+_base_ = [
+    "../_base_/models/dinov2_upernet.py",
+    "../_base_/datasets/water.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/water_schedule.py",
+]
+
+data_preprocessor = dict(size=(512, 512))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=6),
+    auxiliary_head=dict(num_classes=6)
+)
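
This config composes four `_base_` fragments and overrides only what differs: the preprocessor crop size and the class counts of both heads. A minimal sketch of inspecting the merged result (assuming mmengine is installed and the base files are present on disk):

from mmengine import Config

# Loading the file merges the four _base_ fragments, then applies the
# overrides defined above.
cfg = Config.fromfile("configs/dinov2/dinov2_upernet_water.py")
print(cfg.model.decode_head.num_classes)  # 6, from the override
print(cfg.data_preprocessor.size)         # (512, 512)
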
configs/ktda/dinov2_b_frozen-fam-fmm.py
ADDED
@@ -0,0 +1,18 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(
+        num_classes=5,
+    ),
+    auxiliary_head=dict(
+        num_classes=5,
+    ),
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768]),
+)
configs/ktda/dinov2_b_frozen-fam.py
ADDED
@@ -0,0 +1,13 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5)
+)
configs/ktda/experiment_a.py
ADDED
@@ -0,0 +1,14 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    student_training=False,
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5)
+)
configs/ktda/experiment_aa.py
ADDED
@@ -0,0 +1,46 @@
+_base_ = [
+    "../_base_/models/convnextv2_femto_vit_segformer_vegseg.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    teach_backbone=dict(
+        type="mmpretrain.VisionTransformer",
+        arch="large",
+        frozen_stages=24,
+        img_size=256,
+        patch_size=14,
+        layer_scale_init_value=1e-5,
+        out_indices=(7, 11, 15, 23),
+        out_type="featmap",
+        init_cfg=dict(
+            type="Pretrained",
+            checkpoint="checkpoints/dinov2-large.pth",
+            prefix="backbone",
+        ),
+    ),
+    fam=dict(out_channels=1024),
+    decode_head=dict(in_channels=[1024, 1024, 1024, 1024], num_classes=5),
+    data_preprocessor=data_preprocessor,
+    auxiliary_head=[
+        dict(
+            type="FCNHead",
+            in_channels=1024,
+            in_index=i,
+            channels=256,
+            num_convs=1,
+            concat_input=False,
+            dropout_ratio=0.1,
+            num_classes=5,
+            norm_cfg=dict(type="SyncBN", requires_grad=True),
+            align_corners=False,
+            loss_decode=dict(
+                type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4
+            ),
+        )
+        for i in range(4)
+    ],
+)
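
The `auxiliary_head` above is a list comprehension that expands to four FCNHead configs, one per backbone feature level, differing only in `in_index`. A small sketch of what it evaluates to:

# Four otherwise identical head configs whose in_index walks the
# four output feature maps.
heads = [dict(type="FCNHead", in_index=i, num_classes=5) for i in range(4)]
assert [h["in_index"] for h in heads] == [0, 1, 2, 3]
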
configs/ktda/experiment_k.py
ADDED
@@ -0,0 +1,14 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5),
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768], mlp_nums=4),
+)
configs/ktda/experiment_u.py
ADDED
@@ -0,0 +1,15 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5),
+    neck=None,
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768]),
+)
configs/ktda/experiment_v.py
ADDED
@@ -0,0 +1,26 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(
+        _delete_=True,
+        type="SegformerHead",
+        in_channels=[768, 768, 768, 768],
+        in_index=[0, 1, 2, 3],
+        channels=256,
+        dropout_ratio=0.1,
+        num_classes=5,
+        norm_cfg=dict(type="SyncBN", requires_grad=True),
+        align_corners=False,
+        loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
+    ),
+    auxiliary_head=dict(num_classes=5),
+    neck=None,
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768]),
+)
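
`_delete_=True` tells mmengine's config merge to discard the inherited `decode_head` wholesale rather than merging keys into it, which is what lets this file swap the base model's head for a SegformerHead. A sketch with hypothetical keys:

from mmengine import Config

cfg = Config(dict(decode_head=dict(type="UPerHead", channels=512)))
cfg.merge_from_dict({"decode_head": dict(_delete_=True, type="SegformerHead")})
# After the merge, cfg.decode_head keeps only the SegformerHead keys;
# the inherited `channels` entry is gone.
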
configs/ktda/ktda_grass.py
ADDED
@@ -0,0 +1,19 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5),
+    fmm=dict(
+        type="FMM",
+        in_channels=[768, 768, 768, 768],
+        model_type="vitBlock",
+        mlp_nums=4,
+    ),
+)
configs/pspnet/pspnet_r101_water.py
ADDED
@@ -0,0 +1,15 @@
+_base_ = [
+    "../_base_/models/pspnet_r50-d8.py",
+    "../_base_/datasets/water.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/water_schedule.py",
+]
+
+data_preprocessor = dict(size=(512, 512))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    pretrained='open-mmlab://resnet101_v1c',
+    backbone=dict(depth=101),
+    decode_head=dict(num_classes=6),
+    auxiliary_head=dict(num_classes=6)
+)
configs/pspnet/pspnet_r50.py
ADDED
@@ -0,0 +1,13 @@
+_base_ = [
+    "../_base_/models/pspnet_r50-d8.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5)
+)
configs/segformer/segformer_mit-b0_water.py
ADDED
@@ -0,0 +1,14 @@
+_base_ = [
+    "../_base_/models/segformer_mit-b0.py",
+    "../_base_/datasets/water.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/water_schedule.py",
+]
+
+data_preprocessor = dict(size=(512, 512))
+checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth"  # noqa
+model = dict(
+    data_preprocessor=data_preprocessor,
+    backbone=dict(init_cfg=dict(type="Pretrained", checkpoint=checkpoint)),
+    decode_head=dict(num_classes=6),
+)
ktda/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .grass import GrassDataset
+from .l8_biome import L8BIOMEDataset
+
+__all__ = [
+    "GrassDataset",
+    "L8BIOMEDataset"
+]
ktda/datasets/grass.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List
+
+import mmengine.fileio as fileio
+
+from mmseg.registry import DATASETS
+from mmseg.datasets import BaseSegDataset
+
+
+@DATASETS.register_module()
+class GrassDataset(BaseSegDataset):
+    """Grass segmentation dataset. The file structure should be.
+
+    .. code-block:: none
+
+        ├── data
+        │   ├── grass
+        │   │   ├── img_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── 0.tif
+        │   │   │   │   ├── ...
+        │   │   │   ├── val
+        │   │   │   │   ├── 9.tif
+        │   │   │   │   ├── ...
+        │   │   ├── ann_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── 0.png
+        │   │   │   │   ├── ...
+        │   │   │   ├── val
+        │   │   │   │   ├── 9.png
+        │   │   │   │   ├── ...
+    """
+
+    METAINFO = dict(
+        classes=("low", "middle-low", "middle", "middle-high", "high"),
+        palette=[
+            [185, 101, 71],
+            [248, 202, 155],
+            [211, 232, 158],
+            [138, 191, 104],
+            [92, 144, 77],
+        ],
+    )
+
+    def __init__(self,
+                 img_suffix='.tif',
+                 seg_map_suffix='.png',
+                 reduce_zero_label=False,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
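
Once `GrassDataset` is registered, it can be referenced by type name from a dataloader config. A minimal sketch (the paths and empty pipeline are hypothetical, following the directory layout in the docstring):

train_dataloader = dict(
    batch_size=4,
    dataset=dict(
        type="GrassDataset",
        data_root="data/grass",
        data_prefix=dict(img_path="img_dir/train", seg_map_path="ann_dir/train"),
        pipeline=[],  # transforms omitted in this sketch
    ),
)
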
ktda/datasets/l8_biome.py
ADDED
@@ -0,0 +1,29 @@
+from mmseg.registry import DATASETS
+from mmseg.datasets import BaseSegDataset
+
+
+@DATASETS.register_module()
+class L8BIOMEDataset(BaseSegDataset):
+    METAINFO = dict(
+        classes=("Clear", "Cloud Shadow", "Thin Cloud", "Cloud"),
+        palette=[
+            [79, 253, 199],
+            [221, 53, 223],
+            [251, 255, 41],
+            [77, 2, 115],
+        ],
+    )
+
+    def __init__(
+        self,
+        img_suffix=".png",
+        seg_map_suffix=".png",
+        reduce_zero_label=False,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs
+        )
ktda/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .segmentors import DistillEncoderDecoder
+from .adapter import FAM, FMM
+
+__all__ = ["DistillEncoderDecoder", "FAM", "FMM"]
ktda/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (338 Bytes)
ktda/models/adapter/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .fam import FAM
+from .fmm import FMM
+
+__all__ = ["FAM", "FMM"]
ktda/models/adapter/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (288 Bytes)
ktda/models/adapter/__pycache__/fam.cpython-311.pyc
ADDED
Binary file (2.86 kB)
ktda/models/adapter/__pycache__/fmm.cpython-311.pyc
ADDED
Binary file (5.88 kB)
ktda/models/adapter/fam.py
ADDED
@@ -0,0 +1,37 @@
+from mmseg.registry import MODELS
+from mmengine.model import BaseModule
+from torch import nn as nn
+from torch.nn import functional as F
+from timm.models.layers import trunc_normal_
+
+
+@MODELS.register_module()
+class FAM(BaseModule):
+    def __init__(self, in_channels, out_channels, output_size, init_cfg=None):
+        super().__init__(init_cfg)
+        self.convert = nn.ModuleList()
+        self.output_size = output_size
+        if isinstance(out_channels, int):
+            out_channels = [out_channels] * len(in_channels)
+        for in_channel, out_channel in zip(in_channels, out_channels):
+            self.convert.append(
+                nn.Conv2d(in_channel, out_channel, kernel_size=1),
+            )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, inputs):
+        outs = []
+        for index, x in enumerate(inputs):
+            x = self.convert[index](x)
+            x = F.interpolate(
+                x, size=(self.output_size, self.output_size), align_corners=False, mode="bilinear"
+            )
+            outs.append(x)
+        return tuple(outs)
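
FAM aligns heterogeneous feature maps: a 1x1 convolution per level maps each input to a common channel width, and bilinear interpolation resizes every level to `output_size`. A quick shape check (a sketch; the channel and size values are illustrative):

import torch
from ktda.models import FAM

fam = FAM(in_channels=[768] * 4, out_channels=512, output_size=64)
feats = tuple(torch.randn(2, 768, 16, 16) for _ in range(4))
outs = fam(feats)
# every level is projected to 512 channels and resized to 64x64
assert all(o.shape == (2, 512, 64, 64) for o in outs)
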
ktda/models/adapter/fmm.py
ADDED
@@ -0,0 +1,109 @@
+from mmseg.registry import MODELS
+from mmengine.model import BaseModule
+from torch import nn as nn
+from torch.nn import functional as F
+from typing import Callable, Optional
+from torch import Tensor
+from timm.models.layers import trunc_normal_
+from timm.models.vision_transformer import Block as TransformerBlock
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+@MODELS.register_module()
+class FMM(BaseModule):
+    def __init__(
+        self,
+        in_channels,
+        rank_dim=4,
+        mlp_nums=1,
+        model_type="mlp",
+        num_heads=8,
+        mlp_ratio=4,
+        qkv_bias=True,
+        qk_norm=False,
+        init_values=None,
+        proj_drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        init_cfg=None,
+    ):
+        super().__init__(init_cfg)
+        self.adapters = nn.ModuleList()
+        if model_type == "mlp":
+            for in_channel in in_channels:
+                mlp_list = []
+                for _ in range(mlp_nums):
+                    mlp_list.append(
+                        Mlp(
+                            in_channel,
+                            hidden_features=in_channel // rank_dim,
+                            out_features=in_channel,
+                        )
+                    )
+                mlp_model = nn.Sequential(*mlp_list)
+                self.adapters.append(mlp_model)
+
+        elif model_type == "vitBlock":
+            for in_channel in in_channels:
+                model_list = []
+                for _ in range(mlp_nums):
+                    model_list.append(
+                        TransformerBlock(
+                            in_channel,
+                            num_heads=num_heads,
+                            mlp_ratio=mlp_ratio,
+                            qkv_bias=qkv_bias,
+                            qk_norm=qk_norm,
+                            init_values=init_values,
+                            proj_drop=proj_drop_rate,
+                            attn_drop=attn_drop_rate,
+                        )
+                    )
+                self.adapters.append(nn.Sequential(*model_list))
+
+        else:
+            raise ValueError(f"model type must be in ['mlp', 'vitBlock'], actually is {model_type}")
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, inputs):
+        outs = []
+        for index, x in enumerate(inputs):
+            B, C, H, W = x.shape
+            x = x.permute(0, 2, 3, 1)
+            x = x.reshape(B, -1, C)
+            x = self.adapters[index](x)
+            x = x.reshape(B, H, W, C)
+            x = x.permute(0, 3, 1, 2)
+            outs.append(x)
+        return tuple(outs)
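
FMM is shape-preserving: each (B, C, H, W) map is flattened to a token sequence, passed through its per-level adapter (a low-rank MLP stack or a ViT block, depending on `model_type`), and folded back. A sketch:

import torch
from ktda.models import FMM

fmm = FMM(in_channels=[768] * 4, rank_dim=4, model_type="mlp")
feats = tuple(torch.randn(2, 768, 16, 16) for _ in range(4))
outs = fmm(feats)
# each output has exactly the shape of its input feature map
assert all(o.shape == f.shape for o, f in zip(outs, feats))
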
ktda/models/segmentors/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (274 Bytes)
ktda/models/segmentors/__pycache__/distill_encoder_decoder.cpython-311.pyc
ADDED
Binary file (19.8 kB)
ktda/models/segmentors/distill_encoder_decoder.py
ADDED
@@ -0,0 +1,382 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import print_log
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import (
+    ConfigType,
+    OptConfigType,
+    OptMultiConfig,
+    OptSampleList,
+    SampleList,
+    add_prefix,
+)
+from mmseg.models import BaseSegmentor
+
+
+@MODELS.register_module()
+class DistillEncoderDecoder(BaseSegmentor):
+
+    def __init__(
+        self,
+        backbone: ConfigType,
+        teach_backbone: ConfigType,
+        decode_head: ConfigType,
+        neck: OptConfigType = None,
+        auxiliary_head: OptConfigType = None,
+        fam: OptConfigType = None,
+        fmm: OptConfigType = None,
+        train_cfg: OptConfigType = None,
+        test_cfg: OptConfigType = None,
+        data_preprocessor: OptConfigType = None,
+        pretrained: Optional[str] = None,
+        student_training=True,
+        temperature=1.0,
+        alpha=0.5,
+        fuse=False,
+        init_cfg: OptMultiConfig = None,
+    ):
+        super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+
+        self.temperature = temperature
+        self.alpha = alpha
+        self.student_training = student_training
+        self.fuse = fuse
+
+        if pretrained is not None:
+            assert (
+                backbone.get("pretrained") is None
+            ), "both backbone and segmentor set pretrained weight"
+            assert (
+                teach_backbone.get("pretrained") is None
+            ), "both teach backbone and segmentor set pretrained weight"
+            backbone.pretrained = pretrained
+            teach_backbone.pretrained = pretrained
+        self.backbone = MODELS.build(backbone)
+        self.teach_backbone = MODELS.build(teach_backbone)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+
+        self.fam = nn.Identity()
+        self.fmm = nn.Identity()
+        if fam is not None:
+            self.fam = MODELS.build(fam)
+        if fmm is not None:
+            self.fmm = MODELS.build(fmm)
+        self._init_decode_head(decode_head)
+        self._init_auxiliary_head(auxiliary_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        assert self.with_decode_head
+
+    def _init_decode_head(self, decode_head: ConfigType) -> None:
+        """Initialize ``decode_head``"""
+        self.decode_head = MODELS.build(decode_head)
+        self.align_corners = self.decode_head.align_corners
+        self.num_classes = self.decode_head.num_classes
+        self.out_channels = self.decode_head.out_channels
+
+    def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
+        """Initialize ``auxiliary_head``"""
+        if auxiliary_head is not None:
+            if isinstance(auxiliary_head, list):
+                self.auxiliary_head = nn.ModuleList()
+                for head_cfg in auxiliary_head:
+                    self.auxiliary_head.append(MODELS.build(head_cfg))
+            else:
+                self.auxiliary_head = MODELS.build(auxiliary_head)
+
+    def fuse_features(self, features):
+        x = features[0]
+        for index, feature in enumerate(features):
+            if index == 0:
+                continue
+            x += feature
+        x = [x]
+        return tuple(x)
+
+    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
+        """Extract features from images."""
+        x = self.backbone(inputs)
+        x = self.fam(x)
+        if self.fuse:
+            x = self.fuse_features(x)
+        if self.with_neck:
+            x = self.neck(x)
+        x = self.fmm(x)
+        return x
+
+    def encode_decode(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        x = self.extract_feat(inputs)
+        seg_logits = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
+
+        return seg_logits
+
+    def _decode_head_forward_train(
+        self, inputs: List[Tensor], data_samples: SampleList
+    ) -> dict:
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+        loss_decode = self.decode_head.loss(inputs, data_samples, self.train_cfg)
+
+        losses.update(add_prefix(loss_decode, "decode"))
+        return losses
+
+    def _auxiliary_head_forward_train(
+        self, inputs: List[Tensor], data_samples: SampleList
+    ) -> dict:
+        """Run forward function and calculate loss for auxiliary head in
+        training."""
+        losses = dict()
+        if isinstance(self.auxiliary_head, nn.ModuleList):
+            for idx, aux_head in enumerate(self.auxiliary_head):
+                loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
+                for key in loss_aux.keys():
+                    loss_aux[key] = loss_aux[key] / len(self.auxiliary_head)
+                losses.update(add_prefix(loss_aux, f"aux_{idx}"))
+        else:
+            loss_aux = self.auxiliary_head.loss(inputs, data_samples, self.train_cfg)
+            losses.update(add_prefix(loss_aux, "aux"))
+
+        return losses
+
+    def calculate_diltill_loss(self, inputs):
+        student_feats = self.backbone(inputs)
+        student_feats = self.fam(student_feats)
+        teach_feats = self.teach_backbone(inputs)
+
+        if self.fuse:
+            student_feats = self.fuse_features(student_feats)
+            teach_feats = self.fuse_features(teach_feats)
+
+        total_loss = 0.0
+        for student_feat, teach_feat in zip(student_feats, teach_feats):
+            student_prob = F.softmax(student_feat / self.temperature, dim=-1)
+            teach_prob = F.softmax(teach_feat / self.temperature, dim=-1)
+            kl_loss = F.kl_div(
+                student_prob.log(), teach_prob, reduction="batchmean"
+            ) * (self.temperature**2)
+            mse_loss = F.mse_loss(student_feat, teach_feat, reduction="mean")
+            loss = self.alpha * kl_loss + (1 - self.alpha) * mse_loss
+            total_loss += loss
+
+        avg_loss = total_loss / len(student_feats)
+        if self.alpha == 0:
+            avg_loss = avg_loss * 0.5
+        return avg_loss
+
+    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            inputs (Tensor): Input images.
+            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+                It usually includes information such as `metainfo` and
+                `gt_sem_seg`.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+
+        x = self.extract_feat(inputs)
+
+        losses = dict()
+
+        loss_decode = self._decode_head_forward_train(x, data_samples)
+        losses.update(loss_decode)
+        if self.student_training:
+            losses["distill_loss"] = self.calculate_diltill_loss(inputs)
+        if self.with_auxiliary_head:
+            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
+            losses.update(loss_aux)
+
+        return losses
+
+    def predict(self, inputs: Tensor, data_samples: OptSampleList = None) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`], optional): The seg data
+                samples. It usually includes information such as `metainfo`
+                and `gt_sem_seg`.
+
+        Returns:
+            list[:obj:`SegDataSample`]: Segmentation results of the
+            input images. Each SegDataSample usually contain:
+
+            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
+            - ``seg_logits``(PixelData): Predicted logits of semantic
+                segmentation before normalization.
+        """
+        if data_samples is not None:
+            batch_img_metas = [data_sample.metainfo for data_sample in data_samples]
+        else:
+            batch_img_metas = [
+                dict(
+                    ori_shape=inputs.shape[2:],
+                    img_shape=inputs.shape[2:],
+                    pad_shape=inputs.shape[2:],
+                    padding_size=[0, 0, 0, 0],
+                )
+            ] * inputs.shape[0]
+
+        seg_logits = self.inference(inputs, batch_img_metas)
+
+        return self.postprocess_result(seg_logits, data_samples)
+
+    def _forward(self, inputs: Tensor, data_samples: OptSampleList = None) -> Tensor:
+        """Network forward process.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`]): The seg
+                data samples. It usually includes information such
+                as `metainfo` and `gt_sem_seg`.
+
+        Returns:
+            Tensor: Forward output of model without any post-processes.
+        """
+        x = self.extract_feat(inputs)
+        return self.decode_head.forward(x)
+
+    def slide_inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference by sliding-window with overlap.
+
+        If h_crop > h_img or w_crop > w_img, the small patch will be used to
+        decode without padding.
+
+        Args:
+            inputs (tensor): the tensor should have a shape NxCxHxW,
+                which contains all images in the batch.
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', and 'pad_shape'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The segmentation results, seg_logits from model of each
+                input image.
+        """
+
+        h_stride, w_stride = self.test_cfg.stride
+        h_crop, w_crop = self.test_cfg.crop_size
+        batch_size, _, h_img, w_img = inputs.size()
+        out_channels = self.out_channels
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = inputs[:, :, y1:y2, x1:x2]
+                # change the image shape to patch shape
+                batch_img_metas[0]["img_shape"] = crop_img.shape[2:]
+                # the output of encode_decode is seg logits tensor map
+                # with shape [N, C, H, W]
+                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
+                preds += F.pad(
+                    crop_seg_logit,
+                    (
+                        int(x1),
+                        int(preds.shape[3] - x2),
+                        int(y1),
+                        int(preds.shape[2] - y2),
+                    ),
+                )
+
+                count_mat[:, :, y1:y2, x1:x2] += 1
+        assert (count_mat == 0).sum() == 0
+        seg_logits = preds / count_mat
+
+        return seg_logits
+
+    def whole_inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference with full image.
+
+        Args:
+            inputs (Tensor): The tensor should have a shape NxCxHxW, which
+                contains all images in the batch.
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', and 'pad_shape'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The segmentation results, seg_logits from model of each
+                input image.
+        """
+
+        seg_logits = self.encode_decode(inputs, batch_img_metas)
+
+        return seg_logits
+
+    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference with slide/whole style.
+
+        Args:
+            inputs (Tensor): The input image of shape (N, 3, H, W).
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', 'pad_shape', and 'padding_size'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The segmentation results, seg_logits from model of each
+                input image.
+        """
+        assert self.test_cfg.get("mode", "whole") in ["slide", "whole"], (
+            f'Only "slide" or "whole" test mode are supported, but got '
+            f'{self.test_cfg["mode"]}.'
+        )
+        ori_shape = batch_img_metas[0]["ori_shape"]
+        if not all(_["ori_shape"] == ori_shape for _ in batch_img_metas):
+            print_log(
+                "Image shapes are different in the batch.",
+                logger="current",
+                level=logging.WARN,
+            )
+        if self.test_cfg.mode == "slide":
+            seg_logit = self.slide_inference(inputs, batch_img_metas)
+        else:
+            seg_logit = self.whole_inference(inputs, batch_img_metas)
+
+        return seg_logit
+
+    def aug_test(self, inputs, batch_img_metas, rescale=True):
+        """Test with augmentations.
+
+        Only rescale=True is supported.
+        """
+        # aug_test rescale all imgs back to ori_shape for now
+        assert rescale
+        # to save memory, we get augmented seg logit inplace
+        seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
+        for i in range(1, len(inputs)):
+            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], rescale)
+            seg_logit += cur_seg_logit
+        seg_logit /= len(inputs)
+        seg_pred = seg_logit.argmax(dim=1)
+        # unravel batch dim
+        seg_pred = list(seg_pred)
+        return seg_pred
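
The distillation objective in `calculate_diltill_loss` combines a temperature-scaled KL term on softmax-normalized features with an MSE term on the raw features, weighted by `alpha`. Restated standalone (a sketch of the same math):

import torch.nn.functional as F

def distill_loss(student_feat, teacher_feat, T=1.0, alpha=0.5):
    # KL between temperature-softened feature distributions, scaled by T^2
    s = F.softmax(student_feat / T, dim=-1)
    t = F.softmax(teacher_feat / T, dim=-1)
    kl = F.kl_div(s.log(), t, reduction="batchmean") * (T ** 2)
    # plain feature-matching term
    mse = F.mse_loss(student_feat, teacher_feat)
    return alpha * kl + (1 - alpha) * mse
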
requirements/docs.txt
ADDED
@@ -0,0 +1,7 @@
+docutils==0.16.0
+myst-parser
+-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+sphinx==4.0.2
+sphinx_copybutton
+sphinx_markdown_tables
+urllib3<2.0.0
requirements/optional.txt
ADDED
@@ -0,0 +1,22 @@
+cityscapesscripts
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+
+# for vpd model
+diffusers
+einops==0.3.0
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+invisible-watermark
+kornia==0.6
+-e git+https://github.com/CompVis/stable-diffusion@21f890f#egg=latent-diffusion
+nibabel
+omegaconf==2.1.1
+pudb==2019.2
+pytorch-lightning==1.4.2
+streamlit>=0.73.1
+-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+test-tube>=0.7.5
+timm
+torch-fidelity==0.3.0
+torchmetrics==0.6.0
+transformers==4.19.2
requirements/runtime.txt
ADDED
@@ -0,0 +1,5 @@
+matplotlib
+numpy
+packaging
+prettytable
+scipy
tools/analysis_tools/analyze_logs.py
ADDED
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Modified from https://github.com/open-
+mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
+import argparse
+import json
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+def plot_curve(log_dicts, args):
+    if args.backend is not None:
+        plt.switch_backend(args.backend)
+    sns.set_style(args.style)
+    # if legend is None, use {filename}_{key} as legend
+    legend = args.legend
+    if legend is None:
+        legend = []
+        for json_log in args.json_logs:
+            for metric in args.keys:
+                legend.append(f'{json_log}_{metric}')
+    assert len(legend) == (len(args.json_logs) * len(args.keys))
+    metrics = args.keys
+
+    num_metrics = len(metrics)
+    for i, log_dict in enumerate(log_dicts):
+        epochs = list(log_dict.keys())
+        for j, metric in enumerate(metrics):
+            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
+            plot_epochs = []
+            plot_iters = []
+            plot_values = []
+            # In some log files exist lines of validation,
+            # `mode` list is used to only collect iter number
+            # of training line.
+            for epoch in epochs:
+                epoch_logs = log_dict[epoch]
+                if metric not in epoch_logs.keys():
+                    continue
+                if metric in ['mIoU', 'mAcc', 'aAcc']:
+                    plot_epochs.append(epoch)
+                    plot_values.append(epoch_logs[metric][0])
+                else:
+                    for idx in range(len(epoch_logs[metric])):
+                        plot_iters.append(epoch_logs['step'][idx])
+                        plot_values.append(epoch_logs[metric][idx])
+            ax = plt.gca()
+            label = legend[i * num_metrics + j]
+            if metric in ['mIoU', 'mAcc', 'aAcc']:
+                ax.set_xticks(plot_epochs)
+                plt.xlabel('step')
+                plt.plot(plot_epochs, plot_values, label=label, marker='o')
+            else:
+                plt.xlabel('iter')
+                plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
+        plt.legend()
+        if args.title is not None:
+            plt.title(args.title)
+    if args.out is None:
+        plt.show()
+    else:
+        print(f'save curve to: {args.out}')
+        plt.savefig(args.out)
+        plt.cla()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Analyze Json Log')
+    parser.add_argument(
+        'json_logs',
+        type=str,
+        nargs='+',
+        help='path of train log in json format')
+    parser.add_argument(
+        '--keys',
+        type=str,
+        nargs='+',
+        default=['mIoU'],
+        help='the metric that you want to plot')
+    parser.add_argument('--title', type=str, help='title of figure')
+    parser.add_argument(
+        '--legend',
+        type=str,
+        nargs='+',
+        default=None,
+        help='legend of each plot')
+    parser.add_argument(
+        '--backend', type=str, default=None, help='backend of plt')
+    parser.add_argument(
+        '--style', type=str, default='dark', help='style of plt')
+    parser.add_argument('--out', type=str, default=None)
+    args = parser.parse_args()
+    return args
+
+
+def load_json_logs(json_logs):
+    # load and convert json_logs to log_dict, key is step, value is a sub dict
+    # keys of sub dict is different metrics
+    # value of sub dict is a list of corresponding values of all iterations
+    log_dicts = [dict() for _ in json_logs]
+    prev_step = 0
+    for json_log, log_dict in zip(json_logs, log_dicts):
+        with open(json_log) as log_file:
+            for line in log_file:
+                log = json.loads(line.strip())
+                # the final step in json file is 0.
+                if 'step' in log and log['step'] != 0:
+                    step = log['step']
+                    prev_step = step
+                else:
+                    step = prev_step
+                if step not in log_dict:
+                    log_dict[step] = defaultdict(list)
+                for k, v in log.items():
+                    log_dict[step][k].append(v)
+    return log_dicts
+
+
+def main():
+    args = parse_args()
+    json_logs = args.json_logs
+    for json_log in json_logs:
+        assert json_log.endswith('.json')
+    log_dicts = load_json_logs(json_logs)
+    plot_curve(log_dicts, args)
+
+
+if __name__ == '__main__':
+    main()
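
`load_json_logs` groups every logged key by training step. A sketch of the resulting structure for a two-line log (hypothetical contents):

import json
from collections import defaultdict

lines = ['{"step": 100, "loss": 0.8}', '{"step": 200, "loss": 0.6, "mIoU": 41.2}']
log_dict, prev_step = {}, 0
for line in lines:
    log = json.loads(line)
    step = log["step"] if log.get("step", 0) != 0 else prev_step
    prev_step = step
    log_dict.setdefault(step, defaultdict(list))
    for k, v in log.items():
        log_dict[step][k].append(v)
# log_dict == {100: {'step': [100], 'loss': [0.8]},
#              200: {'step': [200], 'loss': [0.6], 'mIoU': [41.2]}}
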
tools/analysis_tools/benchmark.py
ADDED
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+import time
+
+import numpy as np
+import torch
+from mmengine import Config
+from mmengine.fileio import dump
+from mmengine.model.utils import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner import Runner, load_checkpoint
+from mmengine.utils import mkdir_or_exist
+
+from mmseg.registry import MODELS
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--log-interval', type=int, default=50, help='interval of logging')
+    parser.add_argument(
+        '--work-dir',
+        help=('if specified, the results will be dumped '
+              'into the directory as json'))
+    parser.add_argument('--repeat-times', type=int, default=1)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    cfg = Config.fromfile(args.config)
+
+    init_default_scope(cfg.get('default_scope', 'mmseg'))
+
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    if args.work_dir is not None:
+        mkdir_or_exist(osp.abspath(args.work_dir))
+        json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
+    else:
+        # use config filename as default work_dir if cfg.work_dir is None
+        work_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.config))[0])
+        mkdir_or_exist(osp.abspath(work_dir))
+        json_file = osp.join(work_dir, f'fps_{timestamp}.json')
+
+    repeat_times = args.repeat_times
+    # set cudnn_benchmark
+    torch.backends.cudnn.benchmark = False
+    cfg.model.pretrained = None
+
+    benchmark_dict = dict(config=args.config, unit='img / s')
+    overall_fps_list = []
+    cfg.test_dataloader.batch_size = 1
+    for time_index in range(repeat_times):
+        print(f'Run {time_index + 1}:')
+        # build the dataloader
+        data_loader = Runner.build_dataloader(cfg.test_dataloader)
+
+        # build the model and load checkpoint
+        cfg.model.train_cfg = None
+        model = MODELS.build(cfg.model)
+
+        if 'checkpoint' in args and osp.exists(args.checkpoint):
+            load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        model = revert_sync_batchnorm(model)
+
+        model.eval()
+
+        # the first several iterations may be very slow so skip them
+        num_warmup = 5
+        pure_inf_time = 0
+        total_iters = 200
+
+        # benchmark with 200 batches and take the average
+        for i, data in enumerate(data_loader):
+            data = model.data_preprocessor(data, True)
+            inputs = data['inputs']
+            data_samples = data['data_samples']
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            start_time = time.perf_counter()
+
+            with torch.no_grad():
+                model(inputs, data_samples, mode='predict')
+
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+
+            if i >= num_warmup:
+                pure_inf_time += elapsed
+                if (i + 1) % args.log_interval == 0:
+                    fps = (i + 1 - num_warmup) / pure_inf_time
+                    print(f'Done image [{i + 1:<3}/ {total_iters}], '
+                          f'fps: {fps:.2f} img / s')
+
+            if (i + 1) == total_iters:
+                fps = (i + 1 - num_warmup) / pure_inf_time
+                print(f'Overall fps: {fps:.2f} img / s\n')
+                benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
+                overall_fps_list.append(fps)
+                break
+    benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
+    benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
+    print(f'Average fps of {repeat_times} evaluations: '
+          f'{benchmark_dict["average_fps"]}')
+    print(f'The variance of {repeat_times} evaluations: '
+          f'{benchmark_dict["fps_variance"]}')
+    dump(benchmark_dict, json_file, indent=4)
+
+
+if __name__ == '__main__':
+    main()
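
The FPS bookkeeping above excludes the first `num_warmup` iterations, so after i + 1 processed batches the throughput is (i + 1 - num_warmup) / pure_inf_time. A worked example with illustrative numbers:

num_warmup, pure_inf_time, i = 5, 2.5, 199
fps = (i + 1 - num_warmup) / pure_inf_time
assert fps == 78.0  # 195 timed batches over 2.5 s
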
tools/analysis_tools/confusion_matrix.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
from matplotlib.ticker import MultipleLocator
|
8 |
+
from mmengine.config import Config, DictAction
|
9 |
+
from mmengine.registry import init_default_scope
|
10 |
+
from mmengine.utils import mkdir_or_exist, progressbar
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from mmseg.registry import DATASETS
|
14 |
+
|
15 |
+
init_default_scope('mmseg')
|
16 |
+
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = argparse.ArgumentParser(
|
20 |
+
description='Generate confusion matrix from segmentation results')
|
21 |
+
parser.add_argument('config', help='test config file path')
|
22 |
+
parser.add_argument(
|
23 |
+
'prediction_path', help='prediction path where test folder result')
|
24 |
+
parser.add_argument(
|
25 |
+
'save_dir', help='directory where confusion matrix will be saved')
|
26 |
+
parser.add_argument(
|
27 |
+
'--show', action='store_true', help='show confusion matrix')
|
28 |
+
parser.add_argument(
|
29 |
+
'--color-theme',
|
30 |
+
default='winter',
|
31 |
+
help='theme of the matrix color map')
|
32 |
+
parser.add_argument(
|
33 |
+
'--title',
|
34 |
+
default='Normalized Confusion Matrix',
|
35 |
+
help='title of the matrix color map')
|
36 |
+
parser.add_argument(
|
37 |
+
'--cfg-options',
|
38 |
+
nargs='+',
|
39 |
+
action=DictAction,
|
40 |
+
help='override some settings in the used config, the key-value pair '
|
41 |
+
'in xxx=yyy format will be merged into config file. If the value to '
|
42 |
+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
43 |
+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
44 |
+
'Note that the quotation marks are necessary and that no white space '
|
45 |
+
'is allowed.')
|
46 |
+
args = parser.parse_args()
|
47 |
+
return args
|
48 |
+
|
49 |
+
|
50 |
+
def calculate_confusion_matrix(dataset, results):
|
51 |
+
"""Calculate the confusion matrix.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
dataset (Dataset): Test or val dataset.
|
55 |
+
results (list[ndarray]): A list of segmentation results in each image.
|
56 |
+
"""
|
57 |
+
n = len(dataset.METAINFO['classes'])
|
58 |
+
confusion_matrix = np.zeros(shape=[n, n])
|
59 |
+
assert len(dataset) == len(results)
|
60 |
+
ignore_index = dataset.ignore_index
|
61 |
+
reduce_zero_label = dataset.reduce_zero_label
|
62 |
+
prog_bar = progressbar.ProgressBar(len(results))
|
63 |
+
for idx, per_img_res in enumerate(results):
|
64 |
+
res_segm = per_img_res
|
65 |
+
gt_segm = dataset[idx]['data_samples'] \
|
66 |
+
.gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
|
67 |
+
gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
|
68 |
+
if reduce_zero_label:
|
69 |
+
gt_segm = gt_segm - 1
|
70 |
+
to_ignore = gt_segm == ignore_index
|
71 |
+
|
72 |
+
gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
|
73 |
+
inds = n * gt_segm + res_segm
|
74 |
+
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
|
75 |
+
confusion_matrix += mat
|
76 |
+
prog_bar.update()
|
77 |
+
return confusion_matrix
|
78 |
+
|
79 |
+
|
80 |
+
def plot_confusion_matrix(confusion_matrix,
|
81 |
+
labels,
|
82 |
+
save_dir=None,
|
83 |
+
show=True,
|
84 |
+
title='Normalized Confusion Matrix',
|
85 |
+
color_theme='OrRd'):
|
86 |
+
"""Draw confusion matrix with matplotlib.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
confusion_matrix (ndarray): The confusion matrix.
|
90 |
+
labels (list[str]): List of class names.
        save_dir (str|optional): If set, save the confusion matrix plot to the
            given path. Default: None.
        show (bool): Whether to show the plot. Default: True.
        title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
        color_theme (str): Theme of the matrix color map. Default: `winter`.
    """
    # normalize the confusion matrix
    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
    confusion_matrix = \
        confusion_matrix.astype(np.float32) / per_label_sums * 100

    num_classes = len(labels)
    fig, ax = plt.subplots(
        figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
    cmap = plt.get_cmap(color_theme)
    im = ax.imshow(confusion_matrix, cmap=cmap)
    colorbar = plt.colorbar(mappable=im, ax=ax)
    colorbar.ax.tick_params(labelsize=20)  # font size of colorbar tick labels

    title_font = {'weight': 'bold', 'size': 20}
    ax.set_title(title, fontdict=title_font)
    label_font = {'size': 40}
    plt.ylabel('Ground Truth Label', fontdict=label_font)
    plt.xlabel('Prediction Label', fontdict=label_font)

    # draw locator
    xmajor_locator = MultipleLocator(1)
    xminor_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(xmajor_locator)
    ax.xaxis.set_minor_locator(xminor_locator)
    ymajor_locator = MultipleLocator(1)
    yminor_locator = MultipleLocator(0.5)
    ax.yaxis.set_major_locator(ymajor_locator)
    ax.yaxis.set_minor_locator(yminor_locator)

    # draw grid
    ax.grid(True, which='minor', linestyle='-')

    # draw labels
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(labels, fontsize=20)
    ax.set_yticklabels(labels, fontsize=20)

    ax.tick_params(
        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
    plt.setp(
        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    # draw confusion matrix values
    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(
                j,
                i,
                '{}%'.format(
                    round(confusion_matrix[i, j], 2)
                    if not np.isnan(confusion_matrix[i, j]) else -1),
                ha='center',
                va='center',
                color='k',
                size=20)

    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1

    fig.tight_layout()
    if save_dir is not None:
        mkdir_or_exist(save_dir)
        plt.savefig(
            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
    if show:
        plt.show()


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    results = []
    for img in sorted(os.listdir(args.prediction_path)):
        img = os.path.join(args.prediction_path, img)
        image = Image.open(img)
        image = np.copy(image)
        results.append(image)

    assert isinstance(results, list)
    if not isinstance(results[0], np.ndarray):
        raise TypeError('invalid type of prediction results')

    dataset = DATASETS.build(cfg.test_dataloader.dataset)
    confusion_matrix = calculate_confusion_matrix(dataset, results)
    plot_confusion_matrix(
        confusion_matrix,
        dataset.METAINFO['classes'],
        save_dir=args.save_dir,
        show=args.show,
        title=args.title,
        color_theme=args.color_theme)


if __name__ == '__main__':
    main()
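The row-wise normalization in `plot_confusion_matrix` converts raw pixel counts into per-class percentages, so each ground-truth row sums to 100. A minimal standalone sketch of the same arithmetic (toy numbers, not project code):

import numpy as np

counts = np.array([[8, 2],
                   [1, 9]])                           # rows = ground truth
per_label_sums = counts.sum(axis=1)[:, np.newaxis]    # [[10], [10]]
percent = counts.astype(np.float32) / per_label_sums * 100
print(percent)  # [[80. 20.], [10. 90.]] -- each row sums to 100%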
tools/analysis_tools/get_flops.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from pathlib import Path

import torch
from mmengine import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope

from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from vegseg import models  # noqa: F401, imported to register custom modules

try:
    from mmengine.analysis import get_model_complexity_info
    from mmengine.analysis.print_helper import _format_size
except ImportError:
    raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Get the FLOPs of a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[2048, 1024],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    return args


def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
    config_name = Path(args.config)

    if not config_name.exists():
        logger.error(f'Config file {config_name} does not exist')

    cfg: Config = Config.fromfile(config_name)
    cfg.work_dir = tempfile.TemporaryDirectory().name
    cfg.log_level = 'WARN'
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('scope', 'mmseg'))

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')
    result = {}

    model: BaseSegmentor = MODELS.build(cfg.model)
    # drop heads that do not contribute to inference-time FLOPs
    if hasattr(model, 'auxiliary_head'):
        model.auxiliary_head = None
    if hasattr(model, 'teach_backbone'):
        model.teach_backbone = None
    if torch.cuda.is_available():
        model.cuda()
    model = revert_sync_batchnorm(model)
    result['ori_shape'] = input_shape[-2:]
    result['pad_shape'] = input_shape[-2:]
    data_batch = {
        'inputs': [torch.rand(input_shape)],
        'data_samples': [SegDataSample(metainfo=result)]
    }
    data = model.data_preprocessor(data_batch)
    model.eval()
    if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
        # TODO: Support MaskFormer and Mask2Former
        raise NotImplementedError('MaskFormer and Mask2Former are not '
                                  'supported yet.')
    outputs = get_model_complexity_info(
        model,
        input_shape=None,
        inputs=data['inputs'],
        show_table=False,
        show_arch=False)
    result['flops'] = _format_size(outputs['flops'])
    result['params'] = _format_size(outputs['params'])
    result['compute_type'] = 'direct: randomly generate a picture'
    return result


def main():
    args = parse_args()
    logger = MMLogger.get_instance(name='MMLogger')

    result = inference(args, logger)
    split_line = '=' * 30
    ori_shape = result['ori_shape']
    pad_shape = result['pad_shape']
    flops = result['flops']
    params = result['params']
    compute_type = result['compute_type']

    if pad_shape != ori_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {ori_shape} to {pad_shape}')
    print(f'{split_line}\nCompute type: {compute_type}\n'
          f'Input shape: {pad_shape}\nFlops: {flops}\n'
          f'Params: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify '
          'that the flops computation is correct.')


if __name__ == '__main__':
    main()
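For reference, `mmengine.analysis.get_model_complexity_info` can also be driven with an `input_shape` instead of prepared `inputs`; a minimal sanity-check sketch on a toy module (the Conv2d below is illustrative and not part of this repo):

import torch
from mmengine.analysis import get_model_complexity_info

toy = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
outputs = get_model_complexity_info(
    toy, input_shape=(3, 64, 64), show_table=False, show_arch=False)
print(outputs['flops'], outputs['params'])  # raw counts; the script above
                                            # formats them with _format_size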
tools/analysis_tools/visualization_cam.py
ADDED
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).

requirement: pip install grad-cam
"""

from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules


class SemanticSegmentationTarget:
    """Wrap the model output into a scalar Grad-CAM target.

    requirement: pip install grad-cam

    Args:
        category (int): Visualization class.
        mask (ndarray): Mask of class.
        size (tuple): Image size.
    """

    def __init__(self, category, mask, size):
        self.category = category
        self.mask = torch.from_numpy(mask)
        self.size = size
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        model_output = torch.unsqueeze(model_output, dim=0)
        model_output = F.interpolate(
            model_output, size=self.size, mode='bilinear')
        model_output = torch.squeeze(model_output, dim=0)

        return (model_output[self.category, :, :] * self.mask).sum()


def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='prediction.png',
        help='Path to output prediction file')
    parser.add_argument(
        '--cam-file', default='vis_cam.png', help='Path to output cam file')
    parser.add_argument(
        '--target-layers',
        default='backbone.layer4[2]',
        help='Target layers to visualize CAM')
    parser.add_argument(
        '--category-index', default='7', help='Category to visualize CAM')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    register_all_modules()
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)

    # test a single image
    result = inference_model(model, args.img)

    # show the results
    show_result_pyplot(
        model,
        args.img,
        result,
        draw_gt=False,
        show=False if args.out_file is not None else True,
        out_file=args.out_file)

    # result data conversion
    prediction_data = result.pred_sem_seg.data
    pre_np_data = prediction_data.cpu().numpy().squeeze(0)

    target_layers = args.target_layers
    target_layers = [eval(f'model.{target_layers}')]  # resolve the layer path

    category = int(args.category_index)
    mask_float = np.float32(pre_np_data == category)

    # data processing
    image = np.array(Image.open(args.img).convert('RGB'))
    height, width = image.shape[0], image.shape[1]
    rgb_img = np.float32(image) / 255
    config = Config.fromfile(args.config)
    image_mean = config.data_preprocessor['mean']
    image_std = config.data_preprocessor['std']
    input_tensor = preprocess_image(
        rgb_img,
        mean=[x / 255 for x in image_mean],
        std=[x / 255 for x in image_std])

    # Grad CAM (Class Activation Maps)
    # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
    targets = [
        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]
    with GradCAM(
            model=model,
            target_layers=target_layers,
            use_cuda=torch.cuda.is_available()) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # save cam file
    Image.fromarray(cam_image).save(args.cam_file)


if __name__ == '__main__':
    main()
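`SemanticSegmentationTarget` reduces the segmentation logits to a single scalar by summing the chosen class's scores under the predicted mask; Grad-CAM then backpropagates from that scalar. A toy sketch of the reduction (shapes are illustrative):

import torch

logits = torch.randn(5, 4, 4)             # (num_classes, H, W) upsampled output
mask = (torch.rand(4, 4) > 0.5).float()   # 1.0 where the class was predicted
category = 2
score = (logits[category, :, :] * mask).sum()  # scalar target for backprop
print(score)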
tools/dataset_converters/chase_db1.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
from mmengine.utils import mkdir_or_exist

CHASE_DB1_LEN = 28 * 3
TRAINING_LEN = 60


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert CHASE_DB1 dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'CHASE_DB1')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        print('Extracting CHASEDB1.zip...')
        zip_file = zipfile.ZipFile(dataset_path)
        zip_file.extractall(tmp_dir)

        print('Generating training dataset...')

        assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
            f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}'

        for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(tmp_dir, img_name))
            if osp.splitext(img_name)[1] == '.jpg':
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'training',
                             osp.splitext(img_name)[0] + '.png'))
            else:
                # The annotation img should be divided by 128, because some of
                # the annotation imgs are not standard. We should set a
                # threshold to convert the nonstandard annotation imgs. The
                # value divided by 128 is equivalent to '1 if value >= 128
                # else 0'
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'training',
                             osp.splitext(img_name)[0] + '.png'))

        for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(tmp_dir, img_name))
            if osp.splitext(img_name)[1] == '.jpg':
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'validation',
                             osp.splitext(img_name)[0] + '.png'))
            else:
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(img_name)[0] + '.png'))

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
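The `// 128` applied to annotation images is a compact threshold, as the in-code comment claims; a quick check of the equivalence:

import numpy as np

values = np.array([0, 3, 127, 128, 200, 255], dtype=np.uint8)
print(values // 128)                     # [0 0 0 1 1 1]
print((values >= 128).astype(np.uint8))  # identical result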
tools/dataset_converters/cityscapes.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

from cityscapesscripts.preparation.json2labelImg import json2labelImg
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
                            track_progress)


def convert_json_to_label(json_file):
    label_file = json_file.replace('_polygons.json', '_labelTrainIds.png')
    json2labelImg(json_file, label_file, 'trainIds')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to TrainIds')
    parser.add_argument('cityscapes_path', help='cityscapes data path')
    parser.add_argument('--gt-dir', default='gtFine', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cityscapes_path = args.cityscapes_path
    out_dir = args.out_dir if args.out_dir else cityscapes_path
    mkdir_or_exist(out_dir)

    gt_dir = osp.join(cityscapes_path, args.gt_dir)

    poly_files = []
    for poly in scandir(gt_dir, '_polygons.json', recursive=True):
        poly_file = osp.join(gt_dir, poly)
        poly_files.append(poly_file)
    if args.nproc > 1:
        track_parallel_progress(convert_json_to_label, poly_files, args.nproc)
    else:
        track_progress(convert_json_to_label, poly_files)

    split_names = ['train', 'val', 'test']

    for split in split_names:
        filenames = []
        for poly in scandir(
                osp.join(gt_dir, split), '_polygons.json', recursive=True):
            filenames.append(poly.replace('_gtFine_polygons.json', ''))
        with open(osp.join(out_dir, f'{split}.txt'), 'w') as f:
            f.writelines(f + '\n' for f in filenames)


if __name__ == '__main__':
    main()
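The split files written above contain sample stems obtained by stripping the annotation suffix; for example (hypothetical file name):

poly = 'aachen/aachen_000000_000019_gtFine_polygons.json'
print(poly.replace('_gtFine_polygons.json', ''))
# aachen/aachen_000000_000019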
tools/dataset_converters/coco_stuff10k.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
from functools import partial

import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
                            track_progress)
from PIL import Image
from scipy.io import loadmat

COCO_LEN = 10000

# raw COCO-Stuff class id -> contiguous train id (unused ids are skipped)
clsID_to_trID = {
    0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7,
    8: 8, 9: 9, 10: 10, 11: 11, 13: 12, 14: 13, 15: 14, 16: 15,
    17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23,
    25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31,
    36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39,
    44: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47,
    53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55,
    61: 56, 62: 57, 63: 58, 64: 59, 65: 60, 67: 61, 70: 62, 72: 63,
    73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71,
    81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79,
    90: 80, 92: 81, 93: 82, 94: 83, 95: 84, 96: 85, 97: 86, 98: 87,
    99: 88, 100: 89, 101: 90, 102: 91, 103: 92, 104: 93, 105: 94, 106: 95,
    107: 96, 108: 97, 109: 98, 110: 99, 111: 100, 112: 101, 113: 102,
    114: 103, 115: 104, 116: 105, 117: 106, 118: 107, 119: 108, 120: 109,
    121: 110, 122: 111, 123: 112, 124: 113, 125: 114, 126: 115, 127: 116,
    128: 117, 129: 118, 130: 119, 131: 120, 132: 121, 133: 122, 134: 123,
    135: 124, 136: 125, 137: 126, 138: 127, 139: 128, 140: 129, 141: 130,
    142: 131, 143: 132, 144: 133, 145: 134, 146: 135, 147: 136, 148: 137,
    149: 138, 150: 139, 151: 140, 152: 141, 153: 142, 154: 143, 155: 144,
    156: 145, 157: 146, 158: 147, 159: 148, 160: 149, 161: 150, 162: 151,
    163: 152, 164: 153, 165: 154, 166: 155, 167: 156, 168: 157, 169: 158,
    170: 159, 171: 160, 172: 161, 173: 162, 174: 163, 175: 164, 176: 165,
    177: 166, 178: 167, 179: 168, 180: 169, 181: 170, 182: 171
}


def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir,
                       out_mask_dir, is_train):
    imgpath, maskpath = tuple_path
    shutil.copyfile(
        osp.join(in_img_dir, imgpath),
        osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join(
            out_img_dir, 'test2014', imgpath))
    annotate = loadmat(osp.join(in_ann_dir, maskpath))
    mask = annotate['S'].astype(np.uint8)
    mask_copy = mask.copy()
    for clsID, trID in clsID_to_trID.items():
        mask_copy[mask == clsID] = trID
    seg_filename = osp.join(out_mask_dir, 'train2014',
                            maskpath.split('.')[0] +
                            '_labelTrainIds.png') if is_train else osp.join(
                                out_mask_dir, 'test2014',
                                maskpath.split('.')[0] + '_labelTrainIds.png')
    Image.fromarray(mask_copy).save(seg_filename, 'PNG')


def generate_coco_list(folder):
    train_list = osp.join(folder, 'imageLists', 'train.txt')
    test_list = osp.join(folder, 'imageLists', 'test.txt')
    train_paths = []
    test_paths = []

    with open(train_list) as f:
        for filename in f:
            basename = filename.strip()
            imgpath = basename + '.jpg'
            maskpath = basename + '.mat'
            train_paths.append((imgpath, maskpath))

    with open(test_list) as f:
        for filename in f:
            basename = filename.strip()
            imgpath = basename + '.jpg'
            maskpath = basename + '.mat'
            test_paths.append((imgpath, maskpath))

    return train_paths, test_paths


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert COCO Stuff 10k annotations to mmsegmentation '
        'format')
    parser.add_argument('coco_path', help='coco stuff path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=16, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    coco_path = args.coco_path
    nproc = args.nproc

    out_dir = args.out_dir or coco_path
    out_img_dir = osp.join(out_dir, 'images')
    out_mask_dir = osp.join(out_dir, 'annotations')

    mkdir_or_exist(osp.join(out_img_dir, 'train2014'))
    mkdir_or_exist(osp.join(out_img_dir, 'test2014'))
    mkdir_or_exist(osp.join(out_mask_dir, 'train2014'))
    mkdir_or_exist(osp.join(out_mask_dir, 'test2014'))

    train_list, test_list = generate_coco_list(coco_path)
    assert (len(train_list) +
            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
                len(train_list), len(test_list))

    if args.nproc > 1:
        track_parallel_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=True),
            train_list,
            nproc=nproc)
        track_parallel_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=False),
            test_list,
            nproc=nproc)
    else:
        track_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=True), train_list)
        track_progress(
            partial(
                convert_to_trainID,
                in_img_dir=osp.join(coco_path, 'images'),
                in_ann_dir=osp.join(coco_path, 'annotations'),
                out_img_dir=out_img_dir,
                out_mask_dir=out_mask_dir,
                is_train=False), test_list)

    print('Done!')


if __name__ == '__main__':
    main()
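The remapping loop in `convert_to_trainID` rewrites every raw COCO-Stuff id to a contiguous train id. On a toy mask with a two-entry mapping (illustrative values):

import numpy as np

mask = np.array([[0, 13], [13, 0]], dtype=np.uint8)
mapping = {0: 0, 13: 12}
mask_copy = mask.copy()
for clsID, trID in mapping.items():
    mask_copy[mask == clsID] = trID   # compare against the untouched mask
print(mask_copy)  # [[ 0 12] [12  0]]

Comparing against the original `mask` rather than `mask_copy` is what prevents a pixel that was already rewritten from being remapped a second time.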
tools/dataset_converters/coco_stuff164k.py
ADDED
@@ -0,0 +1,265 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
from functools import partial
from glob import glob

import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
                            track_progress)
from PIL import Image

COCO_LEN = 123287

# raw COCO-Stuff class id -> contiguous train id (unused ids are skipped;
# 255 stays the ignore index)
clsID_to_trID = {
    0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7,
    8: 8, 9: 9, 10: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15,
    17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23,
    26: 24, 27: 25, 30: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31,
    36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39,
    45: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47,
    53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55,
    61: 56, 62: 57, 63: 58, 64: 59, 66: 60, 69: 61, 71: 62, 72: 63,
    73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71,
    81: 72, 83: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79,
    91: 80, 92: 81, 93: 82, 94: 83, 95: 84, 96: 85, 97: 86, 98: 87,
    99: 88, 100: 89, 101: 90, 102: 91, 103: 92, 104: 93, 105: 94, 106: 95,
    107: 96, 108: 97, 109: 98, 110: 99, 111: 100, 112: 101, 113: 102,
    114: 103, 115: 104, 116: 105, 117: 106, 118: 107, 119: 108, 120: 109,
    121: 110, 122: 111, 123: 112, 124: 113, 125: 114, 126: 115, 127: 116,
    128: 117, 129: 118, 130: 119, 131: 120, 132: 121, 133: 122, 134: 123,
    135: 124, 136: 125, 137: 126, 138: 127, 139: 128, 140: 129, 141: 130,
    142: 131, 143: 132, 144: 133, 145: 134, 146: 135, 147: 136, 148: 137,
    149: 138, 150: 139, 151: 140, 152: 141, 153: 142, 154: 143, 155: 144,
    156: 145, 157: 146, 158: 147, 159: 148, 160: 149, 161: 150, 162: 151,
    163: 152, 164: 153, 165: 154, 166: 155, 167: 156, 168: 157, 169: 158,
    170: 159, 171: 160, 172: 161, 173: 162, 174: 163, 175: 164, 176: 165,
    177: 166, 178: 167, 179: 168, 180: 169, 181: 170, 255: 255
}


def convert_to_trainID(maskpath, out_mask_dir, is_train):
    mask = np.array(Image.open(maskpath))
    mask_copy = mask.copy()
    for clsID, trID in clsID_to_trID.items():
        mask_copy[mask == clsID] = trID
    seg_filename = osp.join(
        out_mask_dir, 'train2017',
        osp.basename(maskpath).split('.')[0] +
        '_labelTrainIds.png') if is_train else osp.join(
            out_mask_dir, 'val2017',
            osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png')
    Image.fromarray(mask_copy).save(seg_filename, 'PNG')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert COCO Stuff 164k annotations to mmsegmentation '
        'format')
    parser.add_argument('coco_path', help='coco stuff path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=16, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    coco_path = args.coco_path
    nproc = args.nproc

    out_dir = args.out_dir or coco_path
    out_img_dir = osp.join(out_dir, 'images')
    out_mask_dir = osp.join(out_dir, 'annotations')

    mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
    mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))

    if out_dir != coco_path:
        shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)

    train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
    train_list = [file for file in train_list if '_labelTrainIds' not in file]
    test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
    test_list = [file for file in test_list if '_labelTrainIds' not in file]
    assert (len(train_list) +
            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
                len(train_list), len(test_list))

    if args.nproc > 1:
        track_parallel_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
            train_list,
            nproc=nproc)
        track_parallel_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
            test_list,
            nproc=nproc)
    else:
        track_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
            train_list)
        track_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
            test_list)

    print('Done!')


if __name__ == '__main__':
    main()
tools/dataset_converters/hrf.py
ADDED
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
from mmengine.utils import mkdir_or_exist

HRF_LEN = 15
TRAINING_LEN = 5


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert HRF dataset to mmsegmentation format')
    parser.add_argument('healthy_path', help='the path of healthy.zip')
    parser.add_argument(
        'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip')
    parser.add_argument('glaucoma_path', help='the path of glaucoma.zip')
    parser.add_argument(
        'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip')
    parser.add_argument(
        'diabetic_retinopathy_path',
        help='the path of diabetic_retinopathy.zip')
    parser.add_argument(
        'diabetic_retinopathy_manualsegm_path',
        help='the path of diabetic_retinopathy_manualsegm.zip')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    images_path = [
        args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path
    ]
    annotations_path = [
        args.healthy_manualsegm_path, args.glaucoma_manualsegm_path,
        args.diabetic_retinopathy_manualsegm_path
    ]
    if args.out_dir is None:
        out_dir = osp.join('data', 'HRF')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    print('Generating images...')
    for now_path in images_path:
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            zip_file = zipfile.ZipFile(now_path)
            zip_file.extractall(tmp_dir)

            assert len(os.listdir(tmp_dir)) == HRF_LEN, \
                f'len(os.listdir(tmp_dir)) != {HRF_LEN}'

            for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'training',
                             osp.splitext(filename)[0] + '.png'))
            for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, 'images', 'validation',
                             osp.splitext(filename)[0] + '.png'))

    print('Generating annotations...')
    for now_path in annotations_path:
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            zip_file = zipfile.ZipFile(now_path)
            zip_file.extractall(tmp_dir)

            assert len(os.listdir(tmp_dir)) == HRF_LEN, \
                f'len(os.listdir(tmp_dir)) != {HRF_LEN}'

            for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                # The annotation img should be divided by 128, because some of
                # the annotation imgs are not standard. We should set a
                # threshold to convert the nonstandard annotation imgs. The
                # value divided by 128 is equivalent to '1 if value >= 128
                # else 0'
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'training',
                             osp.splitext(filename)[0] + '.png'))
            for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
                img = mmcv.imread(osp.join(tmp_dir, filename))
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(filename)[0] + '.png'))

    print('Done!')


if __name__ == '__main__':
    main()
tools/dataset_converters/isaid.py
ADDED
@@ -0,0 +1,246 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import os
import os.path as osp
import shutil
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist
from PIL import Image

iSAID_palette = \
    {
        0: (0, 0, 0),
        1: (0, 0, 63),
        2: (0, 63, 63),
        3: (0, 63, 0),
        4: (0, 63, 127),
        5: (0, 63, 191),
        6: (0, 63, 255),
        7: (0, 127, 63),
        8: (0, 127, 127),
        9: (0, 0, 127),
        10: (0, 0, 191),
        11: (0, 0, 255),
        12: (0, 191, 127),
        13: (0, 127, 191),
        14: (0, 127, 255),
        15: (0, 100, 155)
    }

iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}


def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
    """RGB-color encoding to grayscale labels."""
    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for c, i in palette.items():
        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
        arr_2d[m] = i

    return arr_2d


def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
    img = np.asarray(Image.open(src_path).convert('RGB'))

    img_H, img_W, _ = img.shape

    # pad images that are smaller than one patch in either dimension
    if img_H < patch_H and img_W > patch_W:
        img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
        img_H, img_W, _ = img.shape
    elif img_H > patch_H and img_W < patch_W:
        img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
        img_H, img_W, _ = img.shape
    elif img_H < patch_H and img_W < patch_W:
        img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
        img_H, img_W, _ = img.shape

    for x in range(0, img_W, patch_W - overlap):
        for y in range(0, img_H, patch_H - overlap):
            x_str = x
            x_end = x + patch_W
            if x_end > img_W:
                diff_x = x_end - img_W
                x_str -= diff_x
                x_end = img_W
            y_str = y
            y_end = y + patch_H
            if y_end > img_H:
                diff_y = y_end - img_H
                y_str -= diff_y
                y_end = img_H

            img_patch = img[y_str:y_end, x_str:x_end, :]
            img_patch = Image.fromarray(img_patch.astype(np.uint8))
            image = osp.basename(src_path).split('.')[0] + '_' + str(
                y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
                    x_end) + '.png'
            save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
            img_patch.save(save_path_image, format='BMP')


def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
    label = mmcv.imread(src_path, channel_order='rgb')
    label = iSAID_convert_from_color(label)
    img_H, img_W = label.shape

    if img_H < patch_H and img_W > patch_W:
        label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
        img_H = patch_H
    elif img_H > patch_H and img_W < patch_W:
        label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
        img_W = patch_W
    elif img_H < patch_H and img_W < patch_W:
        label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
        img_H = patch_H
        img_W = patch_W

    for x in range(0, img_W, patch_W - overlap):
        for y in range(0, img_H, patch_H - overlap):
            x_str = x
            x_end = x + patch_W
            if x_end > img_W:
                diff_x = x_end - img_W
                x_str -= diff_x
                x_end = img_W
            y_str = y
            y_end = y + patch_H
            if y_end > img_H:
                diff_y = y_end - img_H
                y_str -= diff_y
                y_end = img_H

            lab_patch = label[y_str:y_end, x_str:x_end]
            lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')

            image = osp.basename(src_path).split('.')[0].split(
                '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
                    x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
            lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert iSAID dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='iSAID folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')

    parser.add_argument(
        '--patch_width',
        default=896,
        type=int,
        help='width of the cropped image patch')
    parser.add_argument(
        '--patch_height',
        default=896,
        type=int,
        help='height of the cropped image patch')
    parser.add_argument(
        '--overlap_area', default=384, type=int, help='overlap area')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    # image patch width and height
    patch_H, patch_W = args.patch_width, args.patch_height

    overlap = args.overlap_area  # overlap area

    if args.out_dir is None:
        out_dir = osp.join('data', 'iSAID')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))

    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))

    assert os.path.exists(os.path.join(dataset_path, 'train')), \
        f'train is not in {dataset_path}'
    assert os.path.exists(os.path.join(dataset_path, 'val')), \
        f'val is not in {dataset_path}'
    assert os.path.exists(os.path.join(dataset_path, 'test')), \
        f'test is not in {dataset_path}'

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for dataset_mode in ['train', 'val', 'test']:
            print(f'Extracting {dataset_mode} images...')
            img_zipp_list = glob.glob(
                os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
            print('Found the data', img_zipp_list)
            for img_zipp in img_zipp_list:
                zip_file = zipfile.ZipFile(img_zipp)
                zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
            src_path_list = glob.glob(
                os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))

            src_prog_bar = ProgressBar(len(src_path_list))
            for i, img_path in enumerate(src_path_list):
                if dataset_mode != 'test':
                    slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
                                     patch_W, overlap)
                else:
                    shutil.move(img_path,
                                os.path.join(out_dir, 'img_dir', dataset_mode))
                src_prog_bar.update()

            if dataset_mode != 'test':
                label_zipp_list = glob.glob(
                    os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
                                 '*.zip'))
                for label_zipp in label_zipp_list:
                    zip_file = zipfile.ZipFile(label_zipp)
                    zip_file.extractall(
                        os.path.join(tmp_dir, dataset_mode, 'lab'))

                lab_path_list = glob.glob(
                    os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
                                 '*.png'))
                lab_prog_bar = ProgressBar(len(lab_path_list))
                for i, lab_path in enumerate(lab_path_list):
                    slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
                                     patch_W, overlap)
                    lab_prog_bar.update()

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
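The crop loops in `slide_crop_image`/`slide_crop_label` slide the last window back inside the image rather than padding it; a standalone sketch of that boundary adjustment (illustrative sizes):

img_W, patch_W, overlap = 1000, 896, 384
for x in range(0, img_W, patch_W - overlap):
    x_str, x_end = x, x + patch_W
    if x_end > img_W:
        x_str -= x_end - img_W   # shift the window back inside the image
        x_end = img_W
    print(x_str, x_end)  # (0, 896), (104, 1000): full coverage, no padding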
tools/dataset_converters/levircd.py
ADDED
@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp

import mmcv
import numpy as np
from mmengine.utils import ProgressBar


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert LEVIR-CD dataset to mmsegmentation format')
    parser.add_argument('--dataset_path', help='LEVIR-CD folder path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=256)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    input_folder = args.dataset_path
    png_files = glob.glob(
        os.path.join(input_folder, '**/*.png'), recursive=True)
    output_folder = args.out_dir
    prog_bar = ProgressBar(len(png_files))
    for png_file in png_files:
        new_path = os.path.join(
            output_folder,
            os.path.relpath(os.path.dirname(png_file), input_folder))
        os.makedirs(os.path.dirname(new_path), exist_ok=True)
        label = 'label' in png_file
        clip_big_image(png_file, new_path, args, label)
        prog_bar.update()


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ],
                     axis=1)

    if to_label:
        image[image == 255] = 1
        image = image[:, :, 0]
    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y, start_x:end_x] \
            if to_label else image[start_y:end_y, start_x:end_x, :]
        idx = osp.basename(image_path).split('.')[0]
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(clip_save_dir,
                     f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


if __name__ == '__main__':
    main()
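LEVIR-CD change masks are stored as 0/255 RGB images; `clip_big_image` collapses them to single-channel 0/1 labels before cropping. A toy check of that conversion:

import numpy as np

image = np.zeros((2, 2, 3), dtype=np.uint8)
image[0, 0] = 255        # a 'changed' pixel, replicated across channels
image[image == 255] = 1  # 255 -> class id 1
label = image[:, :, 0]   # any channel carries the same mask
print(label)             # [[1 0] [0 0]]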
tools/dataset_converters/loveda.py
ADDED
@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil
import tempfile
import zipfile

from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert LoveDA dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='LoveDA folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'loveDA')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'img_dir'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    assert 'Train.zip' in os.listdir(dataset_path), \
        f'Train.zip is not in {dataset_path}'
    assert 'Val.zip' in os.listdir(dataset_path), \
        f'Val.zip is not in {dataset_path}'
    assert 'Test.zip' in os.listdir(dataset_path), \
        f'Test.zip is not in {dataset_path}'

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for dataset in ['Train', 'Val', 'Test']:
            zip_file = zipfile.ZipFile(
                os.path.join(dataset_path, dataset + '.zip'))
            zip_file.extractall(tmp_dir)
            data_type = dataset.lower()
            for location in ['Rural', 'Urban']:
                for image_type in ['images_png', 'masks_png']:
                    if image_type == 'images_png':
                        dst = osp.join(out_dir, 'img_dir', data_type)
                    else:
                        dst = osp.join(out_dir, 'ann_dir', data_type)
                    # no masks are provided for the test split
                    if dataset == 'Test' and image_type == 'masks_png':
                        continue
                    src_dir = osp.join(tmp_dir, dataset, location, image_type)
                    src_lst = os.listdir(src_dir)
                    for file in src_lst:
                        shutil.move(osp.join(src_dir, file), dst)
    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
tools/dataset_converters/nyu.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil
import tempfile
import zipfile

from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert NYU Depth dataset to mmsegmentation format')
    parser.add_argument('raw_data', help='the path of raw data')
    parser.add_argument(
        '-o', '--out_dir', help='output path', default='./data/nyu')
    args = parser.parse_args()
    return args


def reorganize(raw_data_dir: str, out_dir: str):
    """Reorganize NYU Depth dataset files into the required directory
    structure.

    Args:
        raw_data_dir (str): Path to the raw data directory.
        out_dir (str): Output directory for the organized dataset.
    """

    def move_data(data_list, dst_prefix, fname_func):
        """Move data files from source to destination directory.

        Args:
            data_list (list): List of data file paths.
            dst_prefix (str): Prefix to be added to destination paths.
            fname_func (callable): Function to process file names.
        """
        for data_item in data_list:
            data_item = data_item.strip().strip('/')
            new_item = fname_func(data_item)
            shutil.move(
                osp.join(raw_data_dir, data_item),
                osp.join(out_dir, dst_prefix, new_item))

    def process_phase(phase):
        """Process a dataset phase (e.g., 'train' or 'test')."""
        with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f:
            data = filter(lambda x: len(x.strip()) > 0, f.readlines())
            data = map(lambda x: x.split()[:2], data)
            images, annos = zip(*data)

        move_data(images, f'images/{phase}', lambda x: x.replace('/rgb', ''))
        move_data(annos, f'annotations/{phase}',
                  lambda x: x.replace('/sync_depth', ''))

    process_phase('train')
    process_phase('test')


def main():
    args = parse_args()

    print('Making directories...')
    mkdir_or_exist(args.out_dir)
    for subdir in [
            'images/train', 'images/test', 'annotations/train',
            'annotations/test'
    ]:
        mkdir_or_exist(osp.join(args.out_dir, subdir))

    print('Generating images and annotations...')

    if args.raw_data.endswith('.zip'):
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_file = zipfile.ZipFile(args.raw_data)
            zip_file.extractall(tmp_dir)
            reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir)
    else:
        assert osp.isdir(
            args.raw_data
        ), 'the argument raw_data should be either a zip file or a directory.'
        reorganize(args.raw_data, args.out_dir)

    print('Done!')


if __name__ == '__main__':
    main()
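The two lambdas passed to move_data above only strip the modality prefix from each split-file entry, flattening the scene directory and frame id into one file name. A sketch with hypothetical entries (the real nyu_train.txt lines may differ):

sample_img = 'bedroom_0001/rgb_00045.jpg'         # hypothetical list entry
sample_ann = 'bedroom_0001/sync_depth_00045.png'  # hypothetical list entry
print(sample_img.replace('/rgb', ''))             # bedroom_0001_00045.jpg
print(sample_ann.replace('/sync_depth', ''))      # bedroom_0001_00045.png
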
tools/dataset_converters/pascal_context.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import numpy as np
from detail import Detail
from mmengine.utils import mkdir_or_exist, track_progress
from PIL import Image

_mapping = np.sort(
    np.array([
        0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
        158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
        440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
        85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
    ]))
_key = np.array(range(len(_mapping))).astype('uint8')


def generate_labels(img_id, detail, out_dir):

    def _class_to_index(mask, _mapping, _key):
        # assert that every value in the mask is a known class id
        values = np.unique(mask)
        for i in range(len(values)):
            assert (values[i] in _mapping)
        index = np.digitize(mask.ravel(), _mapping, right=True)
        return _key[index].reshape(mask.shape)

    mask = Image.fromarray(
        _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
    filename = img_id['file_name']
    mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
    return osp.splitext(osp.basename(filename))[0]


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert PASCAL Context annotations to '
        'mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('json_path', help='annotation json filepath')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    devkit_path = args.devkit_path
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
    else:
        out_dir = args.out_dir
    json_path = args.json_path
    mkdir_or_exist(out_dir)
    img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')

    train_detail = Detail(json_path, img_dir, 'train')
    train_ids = train_detail.getImgs()

    val_detail = Detail(json_path, img_dir, 'val')
    val_ids = val_detail.getImgs()

    mkdir_or_exist(
        osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))

    train_list = track_progress(
        partial(generate_labels, detail=train_detail, out_dir=out_dir),
        train_ids)
    with open(
            osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
                     'train.txt'), 'w') as f:
        f.writelines(line + '\n' for line in sorted(train_list))

    val_list = track_progress(
        partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
    with open(
            osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
                     'val.txt'), 'w') as f:
        f.writelines(line + '\n' for line in sorted(val_list))

    print('Done!')


if __name__ == '__main__':
    main()
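_class_to_index above relies on np.digitize over a sorted lookup table to compress the sparse PASCAL Context ids into contiguous train ids. A toy version with three classes:

import numpy as np

mapping = np.sort(np.array([0, 9, 415]))       # sparse class ids
key = np.arange(len(mapping)).astype('uint8')  # contiguous train ids

mask = np.array([[0, 415], [9, 0]])
index = np.digitize(mask.ravel(), mapping, right=True)
print(key[index].reshape(mask.shape))
# [[0 2]
#  [1 0]]
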
tools/dataset_converters/potsdam.py
ADDED
@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert potsdam dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='potsdam folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=512)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    # Original images of the Potsdam dataset are very large, so they are
    # pre-processed here: given a fixed clip size and stride size, each
    # image is cut into overlapping patches. For example, one 5120 x 5120
    # image with clip size 512 and stride size 256 yields 20 x 20 = 400
    # patches of size 512 x 512.
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ],
                     axis=1)

    if to_label:
        color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                              [255, 255, 0], [0, 255, 0], [0, 255, 255],
                              [0, 0, 255]])
        flatten_v = np.matmul(
            image.reshape(-1, c),
            np.array([2, 3, 4]).reshape(3, 1))
        out = np.zeros_like(flatten_v)
        for idx, class_color in enumerate(color_map):
            value_idx = np.matmul(class_color,
                                  np.array([2, 3, 4]).reshape(3, 1))
            out[flatten_v == value_idx] = idx
        image = out.reshape(h, w)

    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y,
                              start_x:end_x] if to_label else image[
                                  start_y:end_y, start_x:end_x, :]
        idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(
                clip_save_dir,
                f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


def main():
    args = parse_args()
    splits = {
        'train': [
            '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
            '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
            '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
        ],
        'val': [
            '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
            '4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
        ]
    }

    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'potsdam')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
    print('Find the data', zipp_list)

    for zipp in zipp_list:
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            zip_file = zipfile.ZipFile(zipp)
            zip_file.extractall(tmp_dir)
            src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
            if not len(src_path_list):
                sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
                src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))

            prog_bar = ProgressBar(len(src_path_list))
            for i, src_path in enumerate(src_path_list):
                idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
                data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
                    'train'] else 'val'
                if 'label' in src_path:
                    dst_dir = osp.join(out_dir, 'ann_dir', data_type)
                    clip_big_image(src_path, dst_dir, args, to_label=True)
                else:
                    dst_dir = osp.join(out_dir, 'img_dir', data_type)
                    clip_big_image(src_path, dst_dir, args, to_label=False)
                prog_bar.update()

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
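A quick trace of the row-count expression in clip_big_image, using a 6000 x 6000 Potsdam tile and the script defaults (clip_size=512, stride_size=256):

import math

h, clip_size, stride_size = 6000, 512, 256
base = math.ceil((h - clip_size) / stride_size)  # ceil(5488 / 256) = 22
covered = base * stride_size + clip_size >= h    # 22 * 256 + 512 = 6144 >= 6000
num_rows = base if covered else base + 1
print(num_rows)  # 22; windows that overrun the edge are shifted back inside
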
tools/dataset_converters/refuge.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert REFUGE dataset to mmsegmentation format')
    parser.add_argument('--raw_data_root', help='the root path of raw data')

    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def extract_img(root: str,
                cur_dir: str,
                out_dir: str,
                mode: str = 'train',
                file_type: str = 'img') -> None:
    """Extract images or masks from one zip archive of the REFUGE release.

    Args:
        root (str): root where the extracted data is saved.
        cur_dir (str): path of the zip file to extract.
        out_dir (str): root dir where the converted data is saved.
        mode (str, optional): dataset split. Defaults to 'train'.
        file_type (str, optional): 'images' or 'annotations'.
            Defaults to 'img'.
    """
    zip_file = zipfile.ZipFile(cur_dir)
    zip_file.extractall(root)
    for cur_dir, dirs, files in os.walk(root):
        # filter child dirs and directories with "Illustration" and "MACOSX"
        if len(dirs) == 0 and \
                cur_dir.split('\\')[-1].find('Illustration') == -1 and \
                cur_dir.find('MACOSX') == -1:

            file_names = [
                file for file in files
                if file.endswith('.jpg') or file.endswith('.bmp')
            ]
            for filename in sorted(file_names):
                img = mmcv.imread(osp.join(cur_dir, filename))

                if file_type == 'annotations':
                    img = img[:, :, 0]
                    img[np.where(img == 0)] = 1
                    img[np.where(img == 128)] = 2
                    img[np.where(img == 255)] = 0
                mmcv.imwrite(
                    img,
                    osp.join(out_dir, file_type, mode,
                             osp.splitext(filename)[0] + '.png'))


def main():
    args = parse_args()

    raw_data_root = args.raw_data_root
    if args.out_dir is None:
        out_dir = osp.join('./data', 'REFUGE')

    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'test'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'test'))

    print('Generating images and annotations...')
    # process data from the child dir on the first rank
    cur_dir, dirs, files = list(os.walk(raw_data_root))[0]
    print('====================')

    files = list(filter(lambda x: x.endswith('.zip'), files))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for file in files:
            # search data folders for training, validation and test
            mode = list(
                filter(lambda x: file.lower().find(x) != -1,
                       ['training', 'test', 'validation']))[0]
            file_root = osp.join(tmp_dir, file[:-4])
            file_type = 'images' if file.find('Anno') == -1 and file.find(
                'GT') == -1 else 'annotations'
            extract_img(file_root, osp.join(cur_dir, file), out_dir, mode,
                        file_type)

    print('Done!')


if __name__ == '__main__':
    main()
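The in-place remap in extract_img turns the REFUGE gray-level masks (commonly 0 = optic cup, 128 = optic disc, 255 = background) into train ids 1/2/0. A standalone sketch of the same mapping:

import numpy as np

gray = np.array([0, 128, 255], dtype=np.uint8)
label = gray.copy()
label[gray == 0] = 1    # optic cup
label[gray == 128] = 2  # optic disc
label[gray == 255] = 0  # background
print(label)  # [1 2 0]
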
tools/dataset_converters/stare.py
ADDED
@@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import gzip
import os
import os.path as osp
import tarfile
import tempfile

import mmcv
from mmengine.utils import mkdir_or_exist

STARE_LEN = 20
TRAINING_LEN = 10


def un_gz(src, dst):
    g_file = gzip.GzipFile(src)
    with open(dst, 'wb+') as f:
        f.write(g_file.read())
    g_file.close()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert STARE dataset to mmsegmentation format')
    parser.add_argument('image_path', help='the path of stare-images.tar')
    parser.add_argument('labels_ah', help='the path of labels-ah.tar')
    parser.add_argument('labels_vk', help='the path of labels-vk.tar')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    image_path = args.image_path
    labels_ah = args.labels_ah
    labels_vk = args.labels_vk
    if args.out_dir is None:
        out_dir = osp.join('data', 'STARE')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        mkdir_or_exist(osp.join(tmp_dir, 'gz'))
        mkdir_or_exist(osp.join(tmp_dir, 'files'))

        print('Extracting stare-images.tar...')
        with tarfile.open(image_path) as f:
            f.extractall(osp.join(tmp_dir, 'gz'))

        for filename in os.listdir(osp.join(tmp_dir, 'gz')):
            un_gz(
                osp.join(tmp_dir, 'gz', filename),
                osp.join(tmp_dir, 'files',
                         osp.splitext(filename)[0]))

        now_dir = osp.join(tmp_dir, 'files')

        assert len(os.listdir(now_dir)) == STARE_LEN, \
            f'len(os.listdir(now_dir)) != {STARE_LEN}'

        for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img,
                osp.join(out_dir, 'images', 'training',
                         osp.splitext(filename)[0] + '.png'))

        for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img,
                osp.join(out_dir, 'images', 'validation',
                         osp.splitext(filename)[0] + '.png'))

        print('Removing the temporary files...')

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        mkdir_or_exist(osp.join(tmp_dir, 'gz'))
        mkdir_or_exist(osp.join(tmp_dir, 'files'))

        print('Extracting labels-ah.tar...')
        with tarfile.open(labels_ah) as f:
            f.extractall(osp.join(tmp_dir, 'gz'))

        for filename in os.listdir(osp.join(tmp_dir, 'gz')):
            un_gz(
                osp.join(tmp_dir, 'gz', filename),
                osp.join(tmp_dir, 'files',
                         osp.splitext(filename)[0]))

        now_dir = osp.join(tmp_dir, 'files')

        assert len(os.listdir(now_dir)) == STARE_LEN, \
            f'len(os.listdir(now_dir)) != {STARE_LEN}'

        for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(now_dir, filename))
            # The annotation img should be divided by 128 because some of
            # the annotation imgs are not standard, so a threshold is needed
            # to convert them. Dividing the value by 128 is equivalent to
            # '1 if value >= 128 else 0'.
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'training',
                         osp.splitext(filename)[0] + '.png'))

        for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'validation',
                         osp.splitext(filename)[0] + '.png'))

        print('Removing the temporary files...')

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        mkdir_or_exist(osp.join(tmp_dir, 'gz'))
        mkdir_or_exist(osp.join(tmp_dir, 'files'))

        print('Extracting labels-vk.tar...')
        with tarfile.open(labels_vk) as f:
            f.extractall(osp.join(tmp_dir, 'gz'))

        for filename in os.listdir(osp.join(tmp_dir, 'gz')):
            un_gz(
                osp.join(tmp_dir, 'gz', filename),
                osp.join(tmp_dir, 'files',
                         osp.splitext(filename)[0]))

        now_dir = osp.join(tmp_dir, 'files')

        assert len(os.listdir(now_dir)) == STARE_LEN, \
            f'len(os.listdir(now_dir)) != {STARE_LEN}'

        for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'training',
                         osp.splitext(filename)[0] + '.png'))

        for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
            img = mmcv.imread(osp.join(now_dir, filename))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'validation',
                         osp.splitext(filename)[0] + '.png'))

        print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
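The '// 128' above is a threshold rather than an exact decode: any gray value of 128 or more becomes vessel (1), anything below becomes background (0), which also normalizes the non-standard masks mentioned in the comment:

import numpy as np

row = np.array([0, 64, 127, 128, 200, 255], dtype=np.uint8)
print(row // 128)  # [0 0 0 1 1 1]
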
tools/dataset_converters/synapse.py
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import nibabel as nib
import numpy as np
from mmengine.utils import mkdir_or_exist
from PIL import Image


def read_files_from_txt(txt_path):
    with open(txt_path) as f:
        files = f.readlines()
    files = [file.strip() for file in files]
    return files


def read_nii_file(nii_path):
    img = nib.load(nii_path).get_fdata()
    return img


def split_3d_image(img):
    c, _, _ = img.shape
    res = []
    for i in range(c):
        res.append(img[i, :, :])
    return res


def label_mapping(label):
    """Label mapping from TransUNet paper setting. It only has 9 classes,
    which are 'background', 'aorta', 'gallbladder', 'left_kidney',
    'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach', respectively.
    Other foreground classes in the original dataset are all set to
    background.

    More details could be found here: https://arxiv.org/abs/2102.04306
    """
    mapped_label = np.zeros_like(label)
    mapped_label[label == 8] = 1
    mapped_label[label == 4] = 2
    mapped_label[label == 3] = 3
    mapped_label[label == 2] = 4
    mapped_label[label == 6] = 5
    mapped_label[label == 11] = 6
    mapped_label[label == 1] = 7
    mapped_label[label == 7] = 8
    return mapped_label


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert synapse dataset to mmsegmentation format')
    parser.add_argument(
        '--dataset-path', type=str, help='synapse dataset path.')
    parser.add_argument(
        '--save-path',
        default='data/synapse',
        type=str,
        help='save path of the dataset.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    save_path = args.save_path

    if not osp.exists(dataset_path):
        raise ValueError('The dataset path does not exist. '
                         'Please enter a correct dataset path.')
    if not osp.exists(osp.join(dataset_path, 'img')) \
            or not osp.exists(osp.join(dataset_path, 'label')):
        raise FileNotFoundError('The dataset structure is incorrect. '
                                'Please check your dataset.')

    train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
    train_id = [idx[3:7] for idx in train_id]

    test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
    test_id = [idx[3:7] for idx in test_id]

    mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
    mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
    mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
    mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))

    # It follows the data preparation pipeline from here:
    # https://github.com/Beckschen/TransUNet/tree/main/datasets
    for i, idx in enumerate(train_id):
        img_3d = read_nii_file(
            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
        label_3d = read_nii_file(
            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))

        img_3d = np.clip(img_3d, -125, 275)
        img_3d = (img_3d + 125) / 400
        img_3d *= 255
        img_3d = np.transpose(img_3d, [2, 0, 1])
        img_3d = np.flip(img_3d, 2)

        label_3d = np.transpose(label_3d, [2, 0, 1])
        label_3d = np.flip(label_3d, 2)
        label_3d = label_mapping(label_3d)

        for c in range(img_3d.shape[0]):
            img = img_3d[c]
            label = label_3d[c]

            img = Image.fromarray(img).convert('RGB')
            label = Image.fromarray(label).convert('L')
            img.save(
                osp.join(
                    save_path, 'img_dir/train', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.jpg'))
            label.save(
                osp.join(
                    save_path, 'ann_dir/train', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.png'))

    for i, idx in enumerate(test_id):
        img_3d = read_nii_file(
            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
        label_3d = read_nii_file(
            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))

        img_3d = np.clip(img_3d, -125, 275)
        img_3d = (img_3d + 125) / 400
        img_3d *= 255
        img_3d = np.transpose(img_3d, [2, 0, 1])
        img_3d = np.flip(img_3d, 2)

        label_3d = np.transpose(label_3d, [2, 0, 1])
        label_3d = np.flip(label_3d, 2)
        label_3d = label_mapping(label_3d)

        for c in range(img_3d.shape[0]):
            img = img_3d[c]
            label = label_3d[c]

            img = Image.fromarray(img).convert('RGB')
            label = Image.fromarray(label).convert('L')
            img.save(
                osp.join(
                    save_path, 'img_dir/val', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.jpg'))
            label.save(
                osp.join(
                    save_path, 'ann_dir/val', 'case' + idx.zfill(4) +
                    '_slice' + str(c).zfill(3) + '.png'))


if __name__ == '__main__':
    main()
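The 'np.clip / +125 / *255' lines implement a fixed intensity window: CT values are clipped to [-125, 275] HU and that 400-unit range is rescaled to [0, 255]. A worked check:

import numpy as np

hu = np.array([-1000.0, -125.0, 75.0, 275.0, 3000.0])
win = np.clip(hu, -125, 275)
img = (win + 125) / 400 * 255
print(img)  # [  0.    0.  127.5 255.  255. ]
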
tools/dataset_converters/voc_aug.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import numpy as np
from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress
from PIL import Image
from scipy.io import loadmat

AUG_LEN = 10582


def convert_mat(mat_file, in_dir, out_dir):
    data = loadmat(osp.join(in_dir, mat_file))
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
    Image.fromarray(mask).save(seg_filename, 'PNG')


def generate_aug_list(merged_list, excluded_list):
    return list(set(merged_list) - set(excluded_list))


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('aug_path', help='pascal voc aug path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    nproc = args.nproc
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    mkdir_or_exist(out_dir)
    in_dir = osp.join(aug_path, 'dataset', 'cls')

    track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(scandir(in_dir, suffix='.mat')),
        nproc=nproc)

    full_aug_list = []
    with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
        full_aug_list += [line.strip() for line in f]
    with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
        full_aug_list += [line.strip() for line in f]

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'train.txt')) as f:
        ori_train_list = [line.strip() for line in f]
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'val.txt')) as f:
        val_list = [line.strip() for line in f]

    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
        AUG_LEN)

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'trainaug.txt'), 'w') as f:
        f.writelines(line + '\n' for line in aug_train_list)

    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    assert len(aug_list) == AUG_LEN - len(
        ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
                                                      len(ori_train_list))
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
            'w') as f:
        f.writelines(line + '\n' for line in aug_list)

    print('Done!')


if __name__ == '__main__':
    main()
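generate_aug_list is plain set subtraction, so 'trainaug' is (VOC train + SBD) minus VOC val, and 'aug' is SBD minus everything already in VOC. In miniature:

merged = ['a', 'b', 'c', 'd']  # e.g. VOC train ids + SBD ids
excluded = ['b', 'd']          # e.g. VOC val ids
print(sorted(set(merged) - set(excluded)))  # ['a', 'c']
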
tools/dataset_tools/create_dataset.py
ADDED
@@ -0,0 +1,185 @@
import argparse
import json
import os
import shutil
from glob import glob
from typing import List, Literal

import cv2
import numpy as np
from PIL import Image
from rich.progress import track
from sklearn.model_selection import train_test_split

# the dataset class lives in this repo's ktda package
from ktda.datasets import GrassDataset


def give_color_to_mask(mask: np.ndarray, palette: List[int]) -> Image.Image:
    """Convert a class-id mask to a paletted color image.

    Args:
        mask (np.ndarray): numpy array of shape (H, W)
        palette (List[int]): flat list of RGB values
    Returns:
        color_mask (Image.Image): paletted PIL image of shape (H, W)
    """
    im = Image.fromarray(mask).convert("P")
    im.putpalette(palette)
    return im


def get_mask_by_json(filename: str) -> np.ndarray:
    """Rasterize a labelme-style json annotation into a class-id mask.

    Args:
        filename (str): path to the json file
    Returns:
        mask (np.ndarray): numpy array of shape (H, W)
    """
    with open(filename) as f:
        json_file = json.load(f)
    img_height = json_file["imageHeight"]
    img_width = json_file["imageWidth"]
    # uint8 so that cv2.fillPoly and PIL.Image.fromarray both accept the mask
    mask = np.zeros((img_height, img_width), dtype="uint8")
    for shape in json_file["shapes"]:
        label = int(shape["label"])
        # labels are shifted down by one; 0 is reserved for background
        label -= 1
        label = max(label, 0)
        points = np.array(shape["points"]).astype(np.int32)
        cv2.fillPoly(mask, [points], label)
    return mask


def json_to_image(json_path, image_path):
    """Convert a json annotation to a paletted mask image.

    Args:
        json_path (str): path to the json file
        image_path (str): path to save the image
    """
    mask = get_mask_by_json(json_path)
    palette_list = GrassDataset.METAINFO["palette"]
    palette = []
    for palette_item in palette_list:
        palette.extend(palette_item)
    color_mask = give_color_to_mask(mask, palette)
    color_mask.save(image_path)


def create_dataset(
    image_paths: List[str],
    ann_paths: List[str],
    phase: Literal["train", "val"],
    output_dir: str,
):
    """Copy images and rasterized annotations into the split directories.

    Args:
        image_paths (List[str]): list of image paths
        ann_paths (List[str]): list of annotation paths
        phase (Literal["train", "val"]): train or val
        output_dir (str): path to save the dataset
    """
    for image_path, ann_path in track(
        zip(image_paths, ann_paths),
        description=f"{phase} dataset",
        total=len(image_paths),
    ):
        ann_save_path = os.path.join(
            output_dir,
            "ann_dir",
            phase,
            os.path.basename(ann_path).replace(".json", ".png"),
        )

        # copy the image into the split directory
        new_image_path = os.path.join(
            output_dir, "img_dir", phase, os.path.basename(image_path)
        )
        shutil.copy(image_path, new_image_path)

        # rasterize and save the annotation next to it
        json_to_image(ann_path, ann_save_path)


def split_dataset(
    root_path: str,
    output_path: str,
    split_ratio: float = 0.8,
    shuffle: bool = True,
    seed: int = 42,
) -> None:
    """Split a dataset into train and validation sets.

    Args:
        root_path (str): Path to the dataset. The dataset should be organized
            as follows:
            dataset_path/
                image1.tif
                ...
                imageN.tif
                image1.json
                ...
                imageN.json
        output_path (str): Path to the output directory where the split
            dataset will be saved.
        split_ratio (float, optional): Ratio of the dataset to be used for
            training. Defaults to 0.8.
        shuffle (bool, optional): Whether to shuffle before splitting.
            Defaults to True.
        seed (int, optional): Seed for the random number generator.
            Defaults to 42.
    """
    image_paths = glob(os.path.join(root_path, "*.tif"))
    ann_paths = [filename.replace("tif", "json") for filename in image_paths]
    assert len(image_paths) == len(
        ann_paths
    ), "Number of images and annotations do not match"
    print(f"images: {len(image_paths)}, annotations: {len(ann_paths)}")

    image_train, image_test, ann_train, ann_test = train_test_split(
        image_paths,
        ann_paths,
        train_size=split_ratio,
        random_state=seed,
        shuffle=shuffle,
    )
    print(f"train: {len(image_train)}, test: {len(image_test)}")

    os.makedirs(os.path.join(output_path, "img_dir", "train"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "img_dir", "val"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "ann_dir", "train"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "ann_dir", "val"), exist_ok=True)

    create_dataset(image_train, ann_train, "train", output_path)
    create_dataset(image_test, ann_test, "val", output_path)


def main():
    args = argparse.ArgumentParser()
    args.add_argument("--root", type=str, default="data/raw_data")
    args.add_argument("--output", type=str, default="data/grass")
    args.add_argument("--split_ratio", type=float, default=0.8)
    args.add_argument("--seed", type=int, default=42)
    # argparse's type=bool treats any non-empty string (even "False") as
    # True, so parse the flag explicitly
    args.add_argument(
        "--shuffle",
        type=lambda x: str(x).lower() in ("true", "1", "yes"),
        default=True,
    )
    args = args.parse_args()

    root: str = args.root
    output_path: str = args.output
    split_ratio: float = args.split_ratio
    seed: int = args.seed
    shuffle: bool = args.shuffle

    split_dataset(
        root_path=root,
        output_path=output_path,
        split_ratio=split_ratio,
        shuffle=shuffle,
        seed=seed,
    )

    print("Dataset split finished")


if __name__ == "__main__":
    # Usage example:
    # python tools/dataset_tools/create_dataset.py --root data/raw_data \
    #     --output data/grass --split_ratio 0.8 --seed 42 --shuffle True
    main()
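A minimal sketch of the polygon rasterization inside get_mask_by_json: a labelme-style shape labeled "2" is filled with class id 1, since labels are shifted down by one with 0 reserved for background (synthetic geometry, not a real annotation):

import cv2
import numpy as np

mask = np.zeros((8, 8), dtype="uint8")
points = np.array([[1, 1], [6, 1], [6, 6], [1, 6]], dtype=np.int32)
cv2.fillPoly(mask, [points], 2 - 1)  # label "2" -> class id 1
print(mask[3, 3])  # 1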