XavierJiezou committed
Commit 918db92 · verified · Parent(s): 3a43a03

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. configs/dinov2/dinov2_upernet_water.py +13 -0
  2. configs/ktda/dinov2_b_frozen-fam-fmm.py +18 -0
  3. configs/ktda/dinov2_b_frozen-fam.py +13 -0
  4. configs/ktda/experiment_a.py +14 -0
  5. configs/ktda/experiment_aa.py +46 -0
  6. configs/ktda/experiment_k.py +14 -0
  7. configs/ktda/experiment_u.py +15 -0
  8. configs/ktda/experiment_v.py +26 -0
  9. configs/ktda/ktda_grass.py +19 -0
  10. configs/pspnet/pspnet_r101_water.py +15 -0
  11. configs/pspnet/pspnet_r50.py +13 -0
  12. configs/segformer/segformer_mit-b0_water.py +14 -0
  13. ktda/datasets/__init__.py +7 -0
  14. ktda/datasets/grass.py +55 -0
  15. ktda/datasets/l8_biome.py +29 -0
  16. ktda/models/__init__.py +4 -0
  17. ktda/models/__pycache__/__init__.cpython-311.pyc +0 -0
  18. ktda/models/adapter/__init__.py +4 -0
  19. ktda/models/adapter/__pycache__/__init__.cpython-311.pyc +0 -0
  20. ktda/models/adapter/__pycache__/fam.cpython-311.pyc +0 -0
  21. ktda/models/adapter/__pycache__/fmm.cpython-311.pyc +0 -0
  22. ktda/models/adapter/fam.py +37 -0
  23. ktda/models/adapter/fmm.py +109 -0
  24. ktda/models/segmentors/__pycache__/__init__.cpython-311.pyc +0 -0
  25. ktda/models/segmentors/__pycache__/distill_encoder_decoder.cpython-311.pyc +0 -0
  26. ktda/models/segmentors/distill_encoder_decoder.py +382 -0
  27. requirements/docs.txt +7 -0
  28. requirements/optional.txt +22 -0
  29. requirements/runtime.txt +5 -0
  30. tools/analysis_tools/analyze_logs.py +130 -0
  31. tools/analysis_tools/benchmark.py +121 -0
  32. tools/analysis_tools/confusion_matrix.py +197 -0
  33. tools/analysis_tools/get_flops.py +126 -0
  34. tools/analysis_tools/visualization_cam.py +127 -0
  35. tools/dataset_converters/chase_db1.py +89 -0
  36. tools/dataset_converters/cityscapes.py +56 -0
  37. tools/dataset_converters/coco_stuff10k.py +308 -0
  38. tools/dataset_converters/coco_stuff164k.py +265 -0
  39. tools/dataset_converters/hrf.py +112 -0
  40. tools/dataset_converters/isaid.py +246 -0
  41. tools/dataset_converters/levircd.py +99 -0
  42. tools/dataset_converters/loveda.py +73 -0
  43. tools/dataset_converters/nyu.py +89 -0
  44. tools/dataset_converters/pascal_context.py +87 -0
  45. tools/dataset_converters/potsdam.py +158 -0
  46. tools/dataset_converters/refuge.py +110 -0
  47. tools/dataset_converters/stare.py +167 -0
  48. tools/dataset_converters/synapse.py +155 -0
  49. tools/dataset_converters/voc_aug.py +92 -0
  50. tools/dataset_tools/create_dataset.py +185 -0
configs/dinov2/dinov2_upernet_water.py ADDED
@@ -0,0 +1,13 @@
+_base_ = [
+    "../_base_/models/dinov2_upernet.py",
+    "../_base_/datasets/water.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/water_schedule.py",
+]
+
+data_preprocessor = dict(size=(512, 512))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=6),
+    auxiliary_head=dict(num_classes=6)
+)
configs/ktda/dinov2_b_frozen-fam-fmm.py ADDED
@@ -0,0 +1,18 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(
+        num_classes=5,
+    ),
+    auxiliary_head=dict(
+        num_classes=5,
+    ),
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768]),
+)
configs/ktda/dinov2_b_frozen-fam.py ADDED
@@ -0,0 +1,13 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5)
+)
configs/ktda/experiment_a.py ADDED
@@ -0,0 +1,14 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    student_training=False,
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5)
+)
configs/ktda/experiment_aa.py ADDED
@@ -0,0 +1,46 @@
+_base_ = [
+    "../_base_/models/convnextv2_femto_vit_segformer_vegseg.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    teach_backbone=dict(
+        type="mmpretrain.VisionTransformer",
+        arch="large",
+        frozen_stages=24,
+        img_size=256,
+        patch_size=14,
+        layer_scale_init_value=1e-5,
+        out_indices=(7, 11, 15, 23),
+        out_type="featmap",
+        init_cfg=dict(
+            type="Pretrained",
+            checkpoint="checkpoints/dinov2-large.pth",
+            prefix="backbone",
+        ),
+    ),
+    fam=dict(out_channels=1024),
+    decode_head=dict(in_channels=[1024, 1024, 1024, 1024], num_classes=5),
+    data_preprocessor=data_preprocessor,
+    auxiliary_head=[
+        dict(
+            type="FCNHead",
+            in_channels=1024,
+            in_index=i,
+            channels=256,
+            num_convs=1,
+            concat_input=False,
+            dropout_ratio=0.1,
+            num_classes=5,
+            norm_cfg=dict(type="SyncBN", requires_grad=True),
+            align_corners=False,
+            loss_decode=dict(
+                type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4
+            ),
+        )
+        for i in range(4)
+    ],
+)
configs/ktda/experiment_k.py ADDED
@@ -0,0 +1,14 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5),
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768], mlp_nums=4),
+)
configs/ktda/experiment_u.py ADDED
@@ -0,0 +1,15 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5),
+    neck=None,
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768]),
+)
configs/ktda/experiment_v.py ADDED
@@ -0,0 +1,26 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(
+        _delete_=True,
+        type="SegformerHead",
+        in_channels=[768, 768, 768, 768],
+        in_index=[0, 1, 2, 3],
+        channels=256,
+        dropout_ratio=0.1,
+        num_classes=5,
+        norm_cfg=dict(type="SyncBN", requires_grad=True),
+        align_corners=False,
+        loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
+    ),
+    auxiliary_head=dict(num_classes=5),
+    neck=None,
+    fmm=dict(type="FMM", in_channels=[768, 768, 768, 768]),
+)
configs/ktda/ktda_grass.py ADDED
@@ -0,0 +1,19 @@
+_base_ = [
+    "../_base_/models/ktda.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5),
+    fmm=dict(
+        type="FMM",
+        in_channels=[768, 768, 768, 768],
+        model_type="vitBlock",
+        mlp_nums=4,
+    ),
+)
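For reference, a minimal sketch (not part of this commit) of how one of these KTDA configs could be launched with MMEngine's standard Runner. The work_dir value and the assumption that the ktda package and its _base_ configs are importable from the working directory are illustrative only.

from mmengine.config import Config
from mmengine.runner import Runner

import ktda.models    # noqa: F401  (registers DistillEncoderDecoder, FAM, FMM)
import ktda.datasets  # noqa: F401  (registers GrassDataset, L8BIOMEDataset)

cfg = Config.fromfile("configs/ktda/ktda_grass.py")
cfg.work_dir = "work_dirs/ktda_grass"  # hypothetical output directory

runner = Runner.from_cfg(cfg)  # builds model, dataloaders and schedule from the config
runner.train()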
configs/pspnet/pspnet_r101_water.py ADDED
@@ -0,0 +1,15 @@
+_base_ = [
+    "../_base_/models/pspnet_r50-d8.py",
+    "../_base_/datasets/water.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/water_schedule.py",
+]
+
+data_preprocessor = dict(size=(512, 512))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    pretrained='open-mmlab://resnet101_v1c',
+    backbone=dict(depth=101),
+    decode_head=dict(num_classes=6),
+    auxiliary_head=dict(num_classes=6)
+)
configs/pspnet/pspnet_r50.py ADDED
@@ -0,0 +1,13 @@
+_base_ = [
+    "../_base_/models/pspnet_r50-d8.py",
+    "../_base_/datasets/grass.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/grass_schedule.py",
+]
+
+data_preprocessor = dict(size=(256, 256))
+model = dict(
+    data_preprocessor=data_preprocessor,
+    decode_head=dict(num_classes=5),
+    auxiliary_head=dict(num_classes=5)
+)
configs/segformer/segformer_mit-b0_water.py ADDED
@@ -0,0 +1,14 @@
+_base_ = [
+    "../_base_/models/segformer_mit-b0.py",
+    "../_base_/datasets/water.py",
+    "../_base_/default_runtime.py",
+    "../_base_/schedules/water_schedule.py",
+]
+
+data_preprocessor = dict(size=(512, 512))
+checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth"  # noqa
+model = dict(
+    data_preprocessor=data_preprocessor,
+    backbone=dict(init_cfg=dict(type="Pretrained", checkpoint=checkpoint)),
+    decode_head=dict(num_classes=6),
+)
ktda/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
+from .grass import GrassDataset
+from .l8_biome import L8BIOMEDataset
+
+__all__ = [
+    "GrassDataset",
+    "L8BIOMEDataset"
+]
ktda/datasets/grass.py ADDED
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List
+
+import mmengine.fileio as fileio
+
+from mmseg.registry import DATASETS
+from mmseg.datasets import BaseSegDataset
+
+
+@DATASETS.register_module()
+class GrassDataset(BaseSegDataset):
+    """Grass segmentation dataset. The file structure should be as follows:
+
+    .. code-block:: none
+
+        ├── data
+        │   ├── grass
+        │   │   ├── img_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── 0.tif
+        │   │   │   │   ├── ...
+        │   │   │   ├── val
+        │   │   │   │   ├── 9.tif
+        │   │   │   │   ├── ...
+        │   │   ├── ann_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── 0.png
+        │   │   │   │   ├── ...
+        │   │   │   ├── val
+        │   │   │   │   ├── 9.png
+        │   │   │   │   ├── ...
+    """
+
+    METAINFO = dict(
+        classes=("low", "middle-low", "middle", "middle-high", "high"),
+        palette=[
+            [185, 101, 71],
+            [248, 202, 155],
+            [211, 232, 158],
+            [138, 191, 104],
+            [92, 144, 77],
+        ],
+    )
+
+    def __init__(self,
+                 img_suffix='.tif',
+                 seg_map_suffix='.png',
+                 reduce_zero_label=False,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
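As a reading aid, a hypothetical dataloader snippet showing how GrassDataset is typically wired into an MMSegmentation config; the data_root, batch size, and pipeline below are placeholders, not values taken from this commit.

train_dataloader = dict(
    batch_size=8,          # placeholder value
    num_workers=4,
    sampler=dict(type="InfiniteSampler", shuffle=True),
    dataset=dict(
        type="GrassDataset",                     # registered by the file above
        data_root="data/grass",                  # matches the docstring layout
        data_prefix=dict(img_path="img_dir/train", seg_map_path="ann_dir/train"),
        pipeline=[
            dict(type="LoadImageFromFile"),
            dict(type="LoadAnnotations"),
            dict(type="PackSegInputs"),
        ],
    ),
)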
ktda/datasets/l8_biome.py ADDED
@@ -0,0 +1,29 @@
+from mmseg.registry import DATASETS
+from mmseg.datasets import BaseSegDataset
+
+
+@DATASETS.register_module()
+class L8BIOMEDataset(BaseSegDataset):
+    METAINFO = dict(
+        classes=("Clear", "Cloud Shadow", "Thin Cloud", "Cloud"),
+        palette=[
+            [79, 253, 199],
+            [221, 53, 223],
+            [251, 255, 41],
+            [77, 2, 115],
+        ],
+    )
+
+    def __init__(
+        self,
+        img_suffix=".png",
+        seg_map_suffix=".png",
+        reduce_zero_label=False,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs
+        )
ktda/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .segmentors import DistillEncoderDecoder
+from .adapter import FAM, FMM
+
+__all__ = ["DistillEncoderDecoder", "FAM", "FMM"]
ktda/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (338 Bytes).
 
ktda/models/adapter/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .fam import FAM
+from .fmm import FMM
+
+__all__ = ["FAM", "FMM"]
ktda/models/adapter/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (288 Bytes).
 
ktda/models/adapter/__pycache__/fam.cpython-311.pyc ADDED
Binary file (2.86 kB).
 
ktda/models/adapter/__pycache__/fmm.cpython-311.pyc ADDED
Binary file (5.88 kB).
 
ktda/models/adapter/fam.py ADDED
@@ -0,0 +1,37 @@
+from mmseg.registry import MODELS
+from mmengine.model import BaseModule
+from torch import nn as nn
+from torch.nn import functional as F
+from timm.models.layers import trunc_normal_
+
+
+@MODELS.register_module()
+class FAM(BaseModule):
+    def __init__(self, in_channels, out_channels, output_size, init_cfg=None):
+        super().__init__(init_cfg)
+        self.convert = nn.ModuleList()
+        self.output_size = output_size
+        if isinstance(out_channels, int):
+            out_channels = [out_channels] * len(in_channels)
+        for in_channel, out_channel in zip(in_channels, out_channels):
+            self.convert.append(
+                nn.Conv2d(in_channel, out_channel, kernel_size=1),
+            )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, inputs):
+        outs = []
+        for index, x in enumerate(inputs):
+            x = self.convert[index](x)
+            x = F.interpolate(
+                x, size=(self.output_size, self.output_size), align_corners=False, mode="bilinear"
+            )
+            outs.append(x)
+        return tuple(outs)
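To illustrate what FAM does, a short self-contained sketch (not part of this commit) that runs it on random multi-scale features; the channel widths and spatial sizes are made-up example values, and it assumes the ktda package and timm are installed.

import torch
from ktda.models.adapter import FAM

# Four backbone stages with different widths, all projected to 768 channels by
# 1x1 convolutions and bilinearly resampled to a common 64x64 grid.
fam = FAM(in_channels=[96, 192, 384, 768], out_channels=768, output_size=64)

feats = [
    torch.randn(2, 96, 64, 64),
    torch.randn(2, 192, 32, 32),
    torch.randn(2, 384, 16, 16),
    torch.randn(2, 768, 8, 8),
]
outs = fam(feats)
print([tuple(o.shape) for o in outs])  # each: (2, 768, 64, 64)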
ktda/models/adapter/fmm.py ADDED
@@ -0,0 +1,109 @@
+from mmseg.registry import MODELS
+from mmengine.model import BaseModule
+from torch import nn as nn
+from torch.nn import functional as F
+from typing import Callable, Optional
+from torch import Tensor
+from timm.models.layers import trunc_normal_
+from timm.models.vision_transformer import Block as TransformerBlock
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+@MODELS.register_module()
+class FMM(BaseModule):
+    def __init__(
+        self,
+        in_channels,
+        rank_dim=4,
+        mlp_nums=1,
+        model_type="mlp",
+        num_heads=8,
+        mlp_ratio=4,
+        qkv_bias=True,
+        qk_norm=False,
+        init_values=None,
+        proj_drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        init_cfg=None,
+    ):
+        super().__init__(init_cfg)
+        self.adapters = nn.ModuleList()
+        if model_type == "mlp":
+            for in_channel in in_channels:
+                mlp_list = []
+                for _ in range(mlp_nums):
+                    mlp_list.append(
+                        Mlp(
+                            in_channel,
+                            hidden_features=in_channel // rank_dim,
+                            out_features=in_channel,
+                        )
+                    )
+                mlp_model = nn.Sequential(*mlp_list)
+                self.adapters.append(mlp_model)
+
+        elif model_type == "vitBlock":
+            for in_channel in in_channels:
+                model_list = []
+                for _ in range(mlp_nums):
+                    model_list.append(
+                        TransformerBlock(
+                            in_channel,
+                            num_heads=num_heads,
+                            mlp_ratio=mlp_ratio,
+                            qkv_bias=qkv_bias,
+                            qk_norm=qk_norm,
+                            init_values=init_values,
+                            proj_drop=proj_drop_rate,
+                            attn_drop=attn_drop_rate,
+                        )
+                    )
+                self.adapters.append(nn.Sequential(*model_list))
+
+        else:
+            raise ValueError(
+                f"model_type must be one of ['mlp', 'vitBlock'], but got {model_type}"
+            )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, inputs):
+        outs = []
+        for index, x in enumerate(inputs):
+            B, C, H, W = x.shape
+            x = x.permute(0, 2, 3, 1)
+            x = x.reshape(B, -1, C)
+            x = self.adapters[index](x)
+            x = x.reshape(B, H, W, C)
+            x = x.permute(0, 3, 1, 2)
+            outs.append(x)
+        return tuple(outs)
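Likewise, an illustrative sketch (not part of this commit) of FMM applied to four feature maps: each map is flattened to tokens, passed through its per-level adapter, and reshaped back. Batch size and spatial sizes are arbitrary example values.

import torch
from ktda.models.adapter import FMM

# One low-rank MLP adapter per feature level (model_type="mlp" is the default).
fmm = FMM(in_channels=[768, 768, 768, 768], model_type="mlp", mlp_nums=1)

feats = [torch.randn(2, 768, 64, 64) for _ in range(4)]
outs = fmm(feats)  # each map: flatten to (B, H*W, C) tokens, adapt, reshape back
print(tuple(outs[0].shape))  # (2, 768, 64, 64)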
ktda/models/segmentors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (274 Bytes).
 
ktda/models/segmentors/__pycache__/distill_encoder_decoder.cpython-311.pyc ADDED
Binary file (19.8 kB).
 
ktda/models/segmentors/distill_encoder_decoder.py ADDED
@@ -0,0 +1,382 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import print_log
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import (
+    ConfigType,
+    OptConfigType,
+    OptMultiConfig,
+    OptSampleList,
+    SampleList,
+    add_prefix,
+)
+from mmseg.models import BaseSegmentor
+
+
+@MODELS.register_module()
+class DistillEncoderDecoder(BaseSegmentor):
+
+    def __init__(
+        self,
+        backbone: ConfigType,
+        teach_backbone: ConfigType,
+        decode_head: ConfigType,
+        neck: OptConfigType = None,
+        auxiliary_head: OptConfigType = None,
+        fam: OptConfigType = None,
+        fmm: OptConfigType = None,
+        train_cfg: OptConfigType = None,
+        test_cfg: OptConfigType = None,
+        data_preprocessor: OptConfigType = None,
+        pretrained: Optional[str] = None,
+        student_training=True,
+        temperature=1.0,
+        alpha=0.5,
+        fuse=False,
+        init_cfg: OptMultiConfig = None,
+    ):
+        super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+
+        self.temperature = temperature
+        self.alpha = alpha
+        self.student_training = student_training
+        self.fuse = fuse
+
+        if pretrained is not None:
+            assert (
+                backbone.get("pretrained") is None
+            ), "both backbone and segmentor set pretrained weight"
+            assert (
+                teach_backbone.get("pretrained") is None
+            ), "both teach backbone and segmentor set pretrained weight"
+            backbone.pretrained = pretrained
+            teach_backbone.pretrained = pretrained
+        self.backbone = MODELS.build(backbone)
+        self.teach_backbone = MODELS.build(teach_backbone)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+
+        self.fam = nn.Identity()
+        self.fmm = nn.Identity()
+        if fam is not None:
+            self.fam = MODELS.build(fam)
+        if fmm is not None:
+            self.fmm = MODELS.build(fmm)
+        self._init_decode_head(decode_head)
+        self._init_auxiliary_head(auxiliary_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        assert self.with_decode_head
+
+    def _init_decode_head(self, decode_head: ConfigType) -> None:
+        """Initialize ``decode_head``"""
+        self.decode_head = MODELS.build(decode_head)
+        self.align_corners = self.decode_head.align_corners
+        self.num_classes = self.decode_head.num_classes
+        self.out_channels = self.decode_head.out_channels
+
+    def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
+        """Initialize ``auxiliary_head``"""
+        if auxiliary_head is not None:
+            if isinstance(auxiliary_head, list):
+                self.auxiliary_head = nn.ModuleList()
+                for head_cfg in auxiliary_head:
+                    self.auxiliary_head.append(MODELS.build(head_cfg))
+            else:
+                self.auxiliary_head = MODELS.build(auxiliary_head)
+
+    def fuse_features(self, features):
+        x = features[0]
+        for index, feature in enumerate(features):
+            if index == 0:
+                continue
+            x += feature
+        x = [x]
+        return tuple(x)
+
+    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
+        """Extract features from images."""
+        x = self.backbone(inputs)
+        x = self.fam(x)
+        if self.fuse:
+            x = self.fuse_features(x)
+        if self.with_neck:
+            x = self.neck(x)
+        x = self.fmm(x)
+        return x
+
+    def encode_decode(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        x = self.extract_feat(inputs)
+        seg_logits = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
+
+        return seg_logits
+
+    def _decode_head_forward_train(
+        self, inputs: List[Tensor], data_samples: SampleList
+    ) -> dict:
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+        loss_decode = self.decode_head.loss(inputs, data_samples, self.train_cfg)
+
+        losses.update(add_prefix(loss_decode, "decode"))
+        return losses
+
+    def _auxiliary_head_forward_train(
+        self, inputs: List[Tensor], data_samples: SampleList
+    ) -> dict:
+        """Run forward function and calculate loss for auxiliary head in
+        training."""
+        losses = dict()
+        if isinstance(self.auxiliary_head, nn.ModuleList):
+            for idx, aux_head in enumerate(self.auxiliary_head):
+                loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
+                for key in loss_aux.keys():
+                    loss_aux[key] = loss_aux[key] / len(self.auxiliary_head)
+                losses.update(add_prefix(loss_aux, f"aux_{idx}"))
+        else:
+            loss_aux = self.auxiliary_head.loss(inputs, data_samples, self.train_cfg)
+            losses.update(add_prefix(loss_aux, "aux"))
+
+        return losses
+
+    def calculate_diltill_loss(self, inputs):
+        student_feats = self.backbone(inputs)
+        student_feats = self.fam(student_feats)
+        teach_feats = self.teach_backbone(inputs)
+
+        if self.fuse:
+            student_feats = self.fuse_features(student_feats)
+            teach_feats = self.fuse_features(teach_feats)
+
+        total_loss = 0.0
+        for student_feat, teach_feat in zip(student_feats, teach_feats):
+            student_prob = F.softmax(student_feat / self.temperature, dim=-1)
+            teach_prob = F.softmax(teach_feat / self.temperature, dim=-1)
+            kl_loss = F.kl_div(
+                student_prob.log(), teach_prob, reduction="batchmean"
+            ) * (self.temperature**2)
+            mse_loss = F.mse_loss(student_feat, teach_feat, reduction="mean")
+            loss = self.alpha * kl_loss + (1 - self.alpha) * mse_loss
+            total_loss += loss
+
+        avg_loss = total_loss / len(student_feats)
+        if self.alpha == 0:
+            avg_loss = avg_loss * 0.5
+        return avg_loss
+
+    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            inputs (Tensor): Input images.
+            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+                It usually includes information such as `metainfo` and
+                `gt_sem_seg`.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+
+        x = self.extract_feat(inputs)
+
+        losses = dict()
+
+        loss_decode = self._decode_head_forward_train(x, data_samples)
+        losses.update(loss_decode)
+        if self.student_training:
+            losses["distill_loss"] = self.calculate_diltill_loss(inputs)
+        if self.with_auxiliary_head:
+            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
+            losses.update(loss_aux)
+
+        return losses
+
+    def predict(self, inputs: Tensor, data_samples: OptSampleList = None) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`], optional): The seg data
+                samples. It usually includes information such as `metainfo`
+                and `gt_sem_seg`.
+
+        Returns:
+            list[:obj:`SegDataSample`]: Segmentation results of the
+            input images. Each SegDataSample usually contain:
+
+            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
+            - ``seg_logits``(PixelData): Predicted logits of semantic
+              segmentation before normalization.
+        """
+        if data_samples is not None:
+            batch_img_metas = [data_sample.metainfo for data_sample in data_samples]
+        else:
+            batch_img_metas = [
+                dict(
+                    ori_shape=inputs.shape[2:],
+                    img_shape=inputs.shape[2:],
+                    pad_shape=inputs.shape[2:],
+                    padding_size=[0, 0, 0, 0],
+                )
+            ] * inputs.shape[0]
+
+        seg_logits = self.inference(inputs, batch_img_metas)
+
+        return self.postprocess_result(seg_logits, data_samples)
+
+    def _forward(self, inputs: Tensor, data_samples: OptSampleList = None) -> Tensor:
+        """Network forward process.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`]): The seg
+                data samples. It usually includes information such
+                as `metainfo` and `gt_sem_seg`.
+
+        Returns:
+            Tensor: Forward output of model without any post-processes.
+        """
+        x = self.extract_feat(inputs)
+        return self.decode_head.forward(x)
+
+    def slide_inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference by sliding-window with overlap.
+
+        If h_crop > h_img or w_crop > w_img, the small patch will be used to
+        decode without padding.
+
+        Args:
+            inputs (tensor): the tensor should have a shape NxCxHxW,
+                which contains all images in the batch.
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', and 'pad_shape'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The segmentation results, seg_logits from model of each
+                input image.
+        """
+
+        h_stride, w_stride = self.test_cfg.stride
+        h_crop, w_crop = self.test_cfg.crop_size
+        batch_size, _, h_img, w_img = inputs.size()
+        out_channels = self.out_channels
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = inputs[:, :, y1:y2, x1:x2]
+                # change the image shape to patch shape
+                batch_img_metas[0]["img_shape"] = crop_img.shape[2:]
+                # the output of encode_decode is seg logits tensor map
+                # with shape [N, C, H, W]
+                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
+                preds += F.pad(
+                    crop_seg_logit,
+                    (
+                        int(x1),
+                        int(preds.shape[3] - x2),
+                        int(y1),
+                        int(preds.shape[2] - y2),
+                    ),
+                )
+
+                count_mat[:, :, y1:y2, x1:x2] += 1
+        assert (count_mat == 0).sum() == 0
+        seg_logits = preds / count_mat
+
+        return seg_logits
+
+    def whole_inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference with full image.
+
+        Args:
+            inputs (Tensor): The tensor should have a shape NxCxHxW, which
+                contains all images in the batch.
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', and 'pad_shape'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The segmentation results, seg_logits from model of each
+                input image.
+        """
+
+        seg_logits = self.encode_decode(inputs, batch_img_metas)
+
+        return seg_logits
+
+    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference with slide/whole style.
+
+        Args:
+            inputs (Tensor): The input image of shape (N, 3, H, W).
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', 'pad_shape', and 'padding_size'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The segmentation results, seg_logits from model of each
+                input image.
+        """
+        assert self.test_cfg.get("mode", "whole") in ["slide", "whole"], (
+            f'Only "slide" or "whole" test mode are supported, but got '
+            f'{self.test_cfg["mode"]}.'
+        )
+        ori_shape = batch_img_metas[0]["ori_shape"]
+        if not all(_["ori_shape"] == ori_shape for _ in batch_img_metas):
+            print_log(
+                "Image shapes are different in the batch.",
+                logger="current",
+                level=logging.WARN,
+            )
+        if self.test_cfg.mode == "slide":
+            seg_logit = self.slide_inference(inputs, batch_img_metas)
+        else:
+            seg_logit = self.whole_inference(inputs, batch_img_metas)
+
+        return seg_logit
+
+    def aug_test(self, inputs, batch_img_metas, rescale=True):
+        """Test with augmentations.
+
+        Only rescale=True is supported.
+        """
+        # aug_test rescale all imgs back to ori_shape for now
+        assert rescale
+        # to save memory, we get augmented seg logit inplace
+        seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
+        for i in range(1, len(inputs)):
+            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], rescale)
+            seg_logit += cur_seg_logit
+        seg_logit /= len(inputs)
+        seg_pred = seg_logit.argmax(dim=1)
+        # unravel batch dim
+        seg_pred = list(seg_pred)
+        return seg_pred
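To make the distillation objective easier to follow, here is a standalone restatement (not part of this commit) of the per-level term computed in calculate_diltill_loss above: a temperature-softened KL divergence blended with an MSE feature-matching term. Tensor shapes are example values.

import torch
import torch.nn.functional as F


def feature_distill_loss(student_feat, teach_feat, temperature=1.0, alpha=0.5):
    # KL divergence between temperature-softened distributions over the last dim,
    # scaled by T^2, matching the loop body above.
    student_prob = F.softmax(student_feat / temperature, dim=-1)
    teach_prob = F.softmax(teach_feat / temperature, dim=-1)
    kl = F.kl_div(student_prob.log(), teach_prob, reduction="batchmean") * temperature**2
    # Plain feature-regression term.
    mse = F.mse_loss(student_feat, teach_feat, reduction="mean")
    return alpha * kl + (1 - alpha) * mse


student = torch.randn(2, 768, 64, 64)  # example student (FAM-aligned) feature
teacher = torch.randn(2, 768, 64, 64)  # example frozen-teacher feature
print(feature_distill_loss(student, teacher).item())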
requirements/docs.txt ADDED
@@ -0,0 +1,7 @@
+docutils==0.16.0
+myst-parser
+-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+sphinx==4.0.2
+sphinx_copybutton
+sphinx_markdown_tables
+urllib3<2.0.0
requirements/optional.txt ADDED
@@ -0,0 +1,22 @@
+cityscapesscripts
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+
+# for vpd model
+diffusers
+einops==0.3.0
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+invisible-watermark
+kornia==0.6
+-e git+https://github.com/CompVis/stable-diffusion@21f890f#egg=latent-diffusion
+nibabel
+omegaconf==2.1.1
+pudb==2019.2
+pytorch-lightning==1.4.2
+streamlit>=0.73.1
+-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+test-tube>=0.7.5
+timm
+torch-fidelity==0.3.0
+torchmetrics==0.6.0
+transformers==4.19.2
requirements/runtime.txt ADDED
@@ -0,0 +1,5 @@
+matplotlib
+numpy
+packaging
+prettytable
+scipy
tools/analysis_tools/analyze_logs.py ADDED
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Modified from https://github.com/open-
+mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
+import argparse
+import json
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+def plot_curve(log_dicts, args):
+    if args.backend is not None:
+        plt.switch_backend(args.backend)
+    sns.set_style(args.style)
+    # if legend is None, use {filename}_{key} as legend
+    legend = args.legend
+    if legend is None:
+        legend = []
+        for json_log in args.json_logs:
+            for metric in args.keys:
+                legend.append(f'{json_log}_{metric}')
+    assert len(legend) == (len(args.json_logs) * len(args.keys))
+    metrics = args.keys
+
+    num_metrics = len(metrics)
+    for i, log_dict in enumerate(log_dicts):
+        epochs = list(log_dict.keys())
+        for j, metric in enumerate(metrics):
+            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
+            plot_epochs = []
+            plot_iters = []
+            plot_values = []
+            # In some log files exist lines of validation,
+            # `mode` list is used to only collect iter number
+            # of training line.
+            for epoch in epochs:
+                epoch_logs = log_dict[epoch]
+                if metric not in epoch_logs.keys():
+                    continue
+                if metric in ['mIoU', 'mAcc', 'aAcc']:
+                    plot_epochs.append(epoch)
+                    plot_values.append(epoch_logs[metric][0])
+                else:
+                    for idx in range(len(epoch_logs[metric])):
+                        plot_iters.append(epoch_logs['step'][idx])
+                        plot_values.append(epoch_logs[metric][idx])
+            ax = plt.gca()
+            label = legend[i * num_metrics + j]
+            if metric in ['mIoU', 'mAcc', 'aAcc']:
+                ax.set_xticks(plot_epochs)
+                plt.xlabel('step')
+                plt.plot(plot_epochs, plot_values, label=label, marker='o')
+            else:
+                plt.xlabel('iter')
+                plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
+            plt.legend()
+        if args.title is not None:
+            plt.title(args.title)
+    if args.out is None:
+        plt.show()
+    else:
+        print(f'save curve to: {args.out}')
+        plt.savefig(args.out)
+        plt.cla()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Analyze Json Log')
+    parser.add_argument(
+        'json_logs',
+        type=str,
+        nargs='+',
+        help='path of train log in json format')
+    parser.add_argument(
+        '--keys',
+        type=str,
+        nargs='+',
+        default=['mIoU'],
+        help='the metric that you want to plot')
+    parser.add_argument('--title', type=str, help='title of figure')
+    parser.add_argument(
+        '--legend',
+        type=str,
+        nargs='+',
+        default=None,
+        help='legend of each plot')
+    parser.add_argument(
+        '--backend', type=str, default=None, help='backend of plt')
+    parser.add_argument(
+        '--style', type=str, default='dark', help='style of plt')
+    parser.add_argument('--out', type=str, default=None)
+    args = parser.parse_args()
+    return args
+
+
+def load_json_logs(json_logs):
+    # load and convert json_logs to log_dict, key is step, value is a sub dict
+    # keys of sub dict is different metrics
+    # value of sub dict is a list of corresponding values of all iterations
+    log_dicts = [dict() for _ in json_logs]
+    prev_step = 0
+    for json_log, log_dict in zip(json_logs, log_dicts):
+        with open(json_log) as log_file:
+            for line in log_file:
+                log = json.loads(line.strip())
+                # the final step in json file is 0.
+                if 'step' in log and log['step'] != 0:
+                    step = log['step']
+                    prev_step = step
+                else:
+                    step = prev_step
+                if step not in log_dict:
+                    log_dict[step] = defaultdict(list)
+                for k, v in log.items():
+                    log_dict[step][k].append(v)
+    return log_dicts
+
+
+def main():
+    args = parse_args()
+    json_logs = args.json_logs
+    for json_log in json_logs:
+        assert json_log.endswith('.json')
+    log_dicts = load_json_logs(json_logs)
+    plot_curve(log_dicts, args)
+
+
+if __name__ == '__main__':
+    main()
tools/analysis_tools/benchmark.py ADDED
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+import time
+
+import numpy as np
+import torch
+from mmengine import Config
+from mmengine.fileio import dump
+from mmengine.model.utils import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner import Runner, load_checkpoint
+from mmengine.utils import mkdir_or_exist
+
+from mmseg.registry import MODELS
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--log-interval', type=int, default=50, help='interval of logging')
+    parser.add_argument(
+        '--work-dir',
+        help=('if specified, the results will be dumped '
+              'into the directory as json'))
+    parser.add_argument('--repeat-times', type=int, default=1)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    cfg = Config.fromfile(args.config)
+
+    init_default_scope(cfg.get('default_scope', 'mmseg'))
+
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    if args.work_dir is not None:
+        mkdir_or_exist(osp.abspath(args.work_dir))
+        json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
+    else:
+        # use config filename as default work_dir if cfg.work_dir is None
+        work_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.config))[0])
+        mkdir_or_exist(osp.abspath(work_dir))
+        json_file = osp.join(work_dir, f'fps_{timestamp}.json')
+
+    repeat_times = args.repeat_times
+    # set cudnn_benchmark
+    torch.backends.cudnn.benchmark = False
+    cfg.model.pretrained = None
+
+    benchmark_dict = dict(config=args.config, unit='img / s')
+    overall_fps_list = []
+    cfg.test_dataloader.batch_size = 1
+    for time_index in range(repeat_times):
+        print(f'Run {time_index + 1}:')
+        # build the dataloader
+        data_loader = Runner.build_dataloader(cfg.test_dataloader)
+
+        # build the model and load checkpoint
+        cfg.model.train_cfg = None
+        model = MODELS.build(cfg.model)
+
+        if 'checkpoint' in args and osp.exists(args.checkpoint):
+            load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        model = revert_sync_batchnorm(model)
+
+        model.eval()
+
+        # the first several iterations may be very slow so skip them
+        num_warmup = 5
+        pure_inf_time = 0
+        total_iters = 200
+
+        # benchmark with 200 batches and take the average
+        for i, data in enumerate(data_loader):
+            data = model.data_preprocessor(data, True)
+            inputs = data['inputs']
+            data_samples = data['data_samples']
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            start_time = time.perf_counter()
+
+            with torch.no_grad():
+                model(inputs, data_samples, mode='predict')
+
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+
+            if i >= num_warmup:
+                pure_inf_time += elapsed
+                if (i + 1) % args.log_interval == 0:
+                    fps = (i + 1 - num_warmup) / pure_inf_time
+                    print(f'Done image [{i + 1:<3}/ {total_iters}], '
+                          f'fps: {fps:.2f} img / s')
+
+            if (i + 1) == total_iters:
+                fps = (i + 1 - num_warmup) / pure_inf_time
+                print(f'Overall fps: {fps:.2f} img / s\n')
+                benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
+                overall_fps_list.append(fps)
+                break
+    benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
+    benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
+    print(f'Average fps of {repeat_times} evaluations: '
+          f'{benchmark_dict["average_fps"]}')
+    print(f'The variance of {repeat_times} evaluations: '
+          f'{benchmark_dict["fps_variance"]}')
+    dump(benchmark_dict, json_file, indent=4)
+
+
+if __name__ == '__main__':
+    main()
tools/analysis_tools/confusion_matrix.py ADDED
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.ticker import MultipleLocator
+from mmengine.config import Config, DictAction
+from mmengine.registry import init_default_scope
+from mmengine.utils import mkdir_or_exist, progressbar
+from PIL import Image
+
+from mmseg.registry import DATASETS
+
+init_default_scope('mmseg')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Generate confusion matrix from segmentation results')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument(
+        'prediction_path', help='prediction path where test folder result')
+    parser.add_argument(
+        'save_dir', help='directory where confusion matrix will be saved')
+    parser.add_argument(
+        '--show', action='store_true', help='show confusion matrix')
+    parser.add_argument(
+        '--color-theme',
+        default='winter',
+        help='theme of the matrix color map')
+    parser.add_argument(
+        '--title',
+        default='Normalized Confusion Matrix',
+        help='title of the matrix color map')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+    return args
+
+
+def calculate_confusion_matrix(dataset, results):
+    """Calculate the confusion matrix.
+
+    Args:
+        dataset (Dataset): Test or val dataset.
+        results (list[ndarray]): A list of segmentation results in each image.
+    """
+    n = len(dataset.METAINFO['classes'])
+    confusion_matrix = np.zeros(shape=[n, n])
+    assert len(dataset) == len(results)
+    ignore_index = dataset.ignore_index
+    reduce_zero_label = dataset.reduce_zero_label
+    prog_bar = progressbar.ProgressBar(len(results))
+    for idx, per_img_res in enumerate(results):
+        res_segm = per_img_res
+        gt_segm = dataset[idx]['data_samples'] \
+            .gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
+        gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
+        if reduce_zero_label:
+            gt_segm = gt_segm - 1
+        to_ignore = gt_segm == ignore_index
+
+        gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
+        inds = n * gt_segm + res_segm
+        mat = np.bincount(inds, minlength=n**2).reshape(n, n)
+        confusion_matrix += mat
+        prog_bar.update()
+    return confusion_matrix
+
+
+def plot_confusion_matrix(confusion_matrix,
+                          labels,
+                          save_dir=None,
+                          show=True,
+                          title='Normalized Confusion Matrix',
+                          color_theme='OrRd'):
+    """Draw confusion matrix with matplotlib.
+
+    Args:
+        confusion_matrix (ndarray): The confusion matrix.
+        labels (list[str]): List of class names.
+        save_dir (str|optional): If set, save the confusion matrix plot to the
+            given path. Default: None.
+        show (bool): Whether to show the plot. Default: True.
+        title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
+        color_theme (str): Theme of the matrix color map. Default: `winter`.
+    """
+    # normalize the confusion matrix
+    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
+    confusion_matrix = \
+        confusion_matrix.astype(np.float32) / per_label_sums * 100
+
+    num_classes = len(labels)
+    fig, ax = plt.subplots(
+        figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
+    cmap = plt.get_cmap(color_theme)
+    im = ax.imshow(confusion_matrix, cmap=cmap)
+    colorbar = plt.colorbar(mappable=im, ax=ax)
+    colorbar.ax.tick_params(labelsize=20)  # set the font size of the colorbar tick labels
+
+    title_font = {'weight': 'bold', 'size': 20}
+    ax.set_title(title, fontdict=title_font)
+    label_font = {'size': 40}
+    plt.ylabel('Ground Truth Label', fontdict=label_font)
+    plt.xlabel('Prediction Label', fontdict=label_font)
+
+    # draw locator
+    xmajor_locator = MultipleLocator(1)
+    xminor_locator = MultipleLocator(0.5)
+    ax.xaxis.set_major_locator(xmajor_locator)
+    ax.xaxis.set_minor_locator(xminor_locator)
+    ymajor_locator = MultipleLocator(1)
+    yminor_locator = MultipleLocator(0.5)
+    ax.yaxis.set_major_locator(ymajor_locator)
+    ax.yaxis.set_minor_locator(yminor_locator)
+
+    # draw grid
+    ax.grid(True, which='minor', linestyle='-')
+
+    # draw label
+    ax.set_xticks(np.arange(num_classes))
+    ax.set_yticks(np.arange(num_classes))
+    ax.set_xticklabels(labels, fontsize=20)
+    ax.set_yticklabels(labels, fontsize=20)
+
+    ax.tick_params(
+        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
+    plt.setp(
+        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
+
+    # draw confusion matrix value
+    for i in range(num_classes):
+        for j in range(num_classes):
+            ax.text(
+                j,
+                i,
+                '{}%'.format(
+                    round(confusion_matrix[i, j], 2
+                          ) if not np.isnan(confusion_matrix[i, j]) else -1),
+                ha='center',
+                va='center',
+                color='k',
+                size=20)
+
+    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1
+
+    fig.tight_layout()
+    if save_dir is not None:
+        mkdir_or_exist(save_dir)
+        plt.savefig(
+            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
+    if show:
+        plt.show()
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    results = []
+    for img in sorted(os.listdir(args.prediction_path)):
+        img = os.path.join(args.prediction_path, img)
+        image = Image.open(img)
+        image = np.copy(image)
+        results.append(image)
+
+    assert isinstance(results, list)
+    if isinstance(results[0], np.ndarray):
+        pass
+    else:
+        raise TypeError('invalid type of prediction results')
+
+    dataset = DATASETS.build(cfg.test_dataloader.dataset)
+    confusion_matrix = calculate_confusion_matrix(dataset, results)
+    plot_confusion_matrix(
+        confusion_matrix,
+        dataset.METAINFO['classes'],
+        save_dir=args.save_dir,
+        show=args.show,
+        title=args.title,
+        color_theme=args.color_theme)
+
+
+if __name__ == '__main__':
+    main()
tools/analysis_tools/get_flops.py ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import tempfile
+from pathlib import Path
+
+import torch
+from mmengine import Config, DictAction
+from mmengine.logging import MMLogger
+from mmengine.model import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+
+from mmseg.models import BaseSegmentor
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+from vegseg import models
+try:
+    from mmengine.analysis import get_model_complexity_info
+    from mmengine.analysis.print_helper import _format_size
+except ImportError:
+    raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Get the FLOPs of a segmentor')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=[2048, 1024],
+        help='input image size')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+    return args
+
+
+def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
+    config_name = Path(args.config)
+
+    if not config_name.exists():
+        logger.error(f'Config file {config_name} does not exist')
+
+    cfg: Config = Config.fromfile(config_name)
+    cfg.work_dir = tempfile.TemporaryDirectory().name
+    cfg.log_level = 'WARN'
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    init_default_scope(cfg.get('scope', 'mmseg'))
+
+    if len(args.shape) == 1:
+        input_shape = (3, args.shape[0], args.shape[0])
+    elif len(args.shape) == 2:
+        input_shape = (3, ) + tuple(args.shape)
+    else:
+        raise ValueError('invalid input shape')
+    result = {}
+
+    model: BaseSegmentor = MODELS.build(cfg.model)
+    if hasattr(model, 'auxiliary_head'):
+        model.auxiliary_head = None
+    if hasattr(model, 'teach_backbone'):
+        model.teach_backbone = None
+    if torch.cuda.is_available():
+        model.cuda()
+    model = revert_sync_batchnorm(model)
+    result['ori_shape'] = input_shape[-2:]
+    result['pad_shape'] = input_shape[-2:]
+    data_batch = {
+        'inputs': [torch.rand(input_shape)],
+        'data_samples': [SegDataSample(metainfo=result)]
+    }
+    data = model.data_preprocessor(data_batch)
+    model.eval()
+    if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
+        # TODO: Support MaskFormer and Mask2Former
+        raise NotImplementedError('MaskFormer and Mask2Former are not '
+                                  'supported yet.')
+    outputs = get_model_complexity_info(
+        model,
+        input_shape=None,
+        inputs=data['inputs'],
+        show_table=False,
+        show_arch=False)
+    result['flops'] = _format_size(outputs['flops'])
+    result['params'] = _format_size(outputs['params'])
+    result['compute_type'] = 'direct: randomly generate a picture'
+    return result
+
+
+def main():
+
+    args = parse_args()
+    logger = MMLogger.get_instance(name='MMLogger')
+
+    result = inference(args, logger)
+    split_line = '=' * 30
+    ori_shape = result['ori_shape']
+    pad_shape = result['pad_shape']
+    flops = result['flops']
+    params = result['params']
+    compute_type = result['compute_type']
+
+    if pad_shape != ori_shape:
+        print(f'{split_line}\nUse size divisor set input shape '
+              f'from {ori_shape} to {pad_shape}')
+    print(f'{split_line}\nCompute type: {compute_type}\n'
+          f'Input shape: {pad_shape}\nFlops: {flops}\n'
+          f'Params: {params}\n{split_line}')
+    print('!!!Please be cautious if you use the results in papers. '
+          'You may need to check if all ops are supported and verify '
+          'that the flops computation is correct.')
+
+
+if __name__ == '__main__':
+    main()
tools/analysis_tools/visualization_cam.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ """Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
3
+
4
+ requirement: pip install grad-cam
5
+ """
6
+
7
+ from argparse import ArgumentParser
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from mmengine import Config
13
+ from mmengine.model import revert_sync_batchnorm
14
+ from PIL import Image
15
+ from pytorch_grad_cam import GradCAM
16
+ from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
17
+
18
+ from mmseg.apis import inference_model, init_model, show_result_pyplot
19
+ from mmseg.utils import register_all_modules
20
+
21
+
22
+ class SemanticSegmentationTarget:
23
+ """wrap the model.
24
+
25
+ requirement: pip install grad-cam
26
+
27
+ Args:
28
+ category (int): Visualization class.
29
+ mask (ndarray): Mask of class.
30
+ size (tuple): Image size.
31
+ """
32
+
33
+ def __init__(self, category, mask, size):
34
+ self.category = category
35
+ self.mask = torch.from_numpy(mask)
36
+ self.size = size
37
+ if torch.cuda.is_available():
38
+ self.mask = self.mask.cuda()
39
+
40
+ def __call__(self, model_output):
41
+ model_output = torch.unsqueeze(model_output, dim=0)
42
+ model_output = F.interpolate(
43
+ model_output, size=self.size, mode='bilinear')
44
+ model_output = torch.squeeze(model_output, dim=0)
45
+
46
+ return (model_output[self.category, :, :] * self.mask).sum()
47
+
48
+
49
+ def main():
50
+ parser = ArgumentParser()
51
+ parser.add_argument('img', help='Image file')
52
+ parser.add_argument('config', help='Config file')
53
+ parser.add_argument('checkpoint', help='Checkpoint file')
54
+ parser.add_argument(
55
+ '--out-file',
56
+ default='prediction.png',
57
+ help='Path to output prediction file')
58
+ parser.add_argument(
59
+ '--cam-file', default='vis_cam.png', help='Path to output cam file')
60
+ parser.add_argument(
61
+ '--target-layers',
62
+ default='backbone.layer4[2]',
63
+ help='Target layers to visualize CAM')
64
+ parser.add_argument(
65
+ '--category-index', default='7', help='Category to visualize CAM')
66
+ parser.add_argument(
67
+ '--device', default='cuda:0', help='Device used for inference')
68
+ args = parser.parse_args()
69
+
70
+ # build the model from a config file and a checkpoint file
71
+ register_all_modules()
72
+ model = init_model(args.config, args.checkpoint, device=args.device)
73
+ if args.device == 'cpu':
74
+ model = revert_sync_batchnorm(model)
75
+
76
+ # test a single image
77
+ result = inference_model(model, args.img)
78
+
79
+ # show the results
80
+ show_result_pyplot(
81
+ model,
82
+ args.img,
83
+ result,
84
+ draw_gt=False,
85
+ show=False if args.out_file is not None else True,
86
+ out_file=args.out_file)
87
+
88
+ # result data conversion
89
+ prediction_data = result.pred_sem_seg.data
90
+ pre_np_data = prediction_data.cpu().numpy().squeeze(0)
91
+
92
+ target_layers = args.target_layers
93
+ target_layers = [eval(f'model.{target_layers}')]
94
+
95
+ category = int(args.category_index)
96
+ mask_float = np.float32(pre_np_data == category)
97
+
98
+ # data processing
99
+ image = np.array(Image.open(args.img).convert('RGB'))
100
+ height, width = image.shape[0], image.shape[1]
101
+ rgb_img = np.float32(image) / 255
102
+ config = Config.fromfile(args.config)
103
+ image_mean = config.data_preprocessor['mean']
104
+ image_std = config.data_preprocessor['std']
105
+ input_tensor = preprocess_image(
106
+ rgb_img,
107
+ mean=[x / 255 for x in image_mean],
108
+ std=[x / 255 for x in image_std])
109
+
110
+ # Grad CAM(Class Activation Maps)
111
+ # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
112
+ targets = [
113
+ SemanticSegmentationTarget(category, mask_float, (height, width))
114
+ ]
115
+ with GradCAM(
116
+ model=model,
117
+ target_layers=target_layers,
118
+ use_cuda=torch.cuda.is_available()) as cam:
119
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
120
+ cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
121
+
122
+ # save cam file
123
+ Image.fromarray(cam_image).save(args.cam_file)
124
+
125
+
126
+ if __name__ == '__main__':
127
+ main()
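The SemanticSegmentationTarget wrapper above reduces the segmentation logits to a single scalar that Grad-CAM can differentiate: the logits of the chosen category are upsampled to the image size and summed over the predicted mask of that category. A self-contained sketch of that reduction, with random tensors standing in for a real model output (the class count and shapes below are illustrative only):

import numpy as np
import torch
import torch.nn.functional as F

# Illustrative shapes: 19 classes, 64x64 logits, 256x256 image.
logits = torch.randn(19, 64, 64)                    # stand-in for the model output
mask = np.float32(np.random.rand(256, 256) > 0.5)   # stand-in for pred_sem_seg == category
category, size = 7, (256, 256)

# Same reduction as SemanticSegmentationTarget.__call__:
up = F.interpolate(logits[None], size=size, mode='bilinear').squeeze(0)
score = (up[category] * torch.from_numpy(mask)).sum()
print(score)  # the scalar Grad-CAM back-propagates from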
tools/dataset_converters/chase_db1.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import tempfile
6
+ import zipfile
7
+
8
+ import mmcv
9
+ from mmengine.utils import mkdir_or_exist
10
+
11
+ CHASE_DB1_LEN = 28 * 3
12
+ TRAINING_LEN = 60
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(
17
+ description='Convert CHASE_DB1 dataset to mmsegmentation format')
18
+ parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
19
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
20
+ parser.add_argument('-o', '--out_dir', help='output path')
21
+ args = parser.parse_args()
22
+ return args
23
+
24
+
25
+ def main():
26
+ args = parse_args()
27
+ dataset_path = args.dataset_path
28
+ if args.out_dir is None:
29
+ out_dir = osp.join('data', 'CHASE_DB1')
30
+ else:
31
+ out_dir = args.out_dir
32
+
33
+ print('Making directories...')
34
+ mkdir_or_exist(out_dir)
35
+ mkdir_or_exist(osp.join(out_dir, 'images'))
36
+ mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
37
+ mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
38
+ mkdir_or_exist(osp.join(out_dir, 'annotations'))
39
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
40
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
41
+
42
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
43
+ print('Extracting CHASEDB1.zip...')
44
+ zip_file = zipfile.ZipFile(dataset_path)
45
+ zip_file.extractall(tmp_dir)
46
+
47
+ print('Generating training dataset...')
48
+
49
+ assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
50
+ f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}'
51
+
52
+ for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
53
+ img = mmcv.imread(osp.join(tmp_dir, img_name))
54
+ if osp.splitext(img_name)[1] == '.jpg':
55
+ mmcv.imwrite(
56
+ img,
57
+ osp.join(out_dir, 'images', 'training',
58
+ osp.splitext(img_name)[0] + '.png'))
59
+ else:
60
+ # The annotation img should be divided by 128, because some of
61
+ # the annotation imgs are not standard. We should set a
62
+ # threshold to convert the nonstandard annotation imgs. The
63
+ # value divided by 128 is equivalent to '1 if value >= 128
64
+ # else 0'
65
+ mmcv.imwrite(
66
+ img[:, :, 0] // 128,
67
+ osp.join(out_dir, 'annotations', 'training',
68
+ osp.splitext(img_name)[0] + '.png'))
69
+
70
+ for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
71
+ img = mmcv.imread(osp.join(tmp_dir, img_name))
72
+ if osp.splitext(img_name)[1] == '.jpg':
73
+ mmcv.imwrite(
74
+ img,
75
+ osp.join(out_dir, 'images', 'validation',
76
+ osp.splitext(img_name)[0] + '.png'))
77
+ else:
78
+ mmcv.imwrite(
79
+ img[:, :, 0] // 128,
80
+ osp.join(out_dir, 'annotations', 'validation',
81
+ osp.splitext(img_name)[0] + '.png'))
82
+
83
+ print('Removing the temporary files...')
84
+
85
+ print('Done!')
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()
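The integer division by 128 applied to the annotation images is just a threshold on uint8 data: value // 128 is 1 exactly when value >= 128 and 0 otherwise, which binarizes the slightly non-standard CHASE_DB1 masks. A quick check on made-up pixel values:

import numpy as np

values = np.array([0, 1, 127, 128, 200, 255], dtype=np.uint8)
by_division = values // 128
by_threshold = (values >= 128).astype(np.uint8)
assert np.array_equal(by_division, by_threshold)
print(by_division)  # [0 0 0 1 1 1]

The same trick is reused by the HRF and STARE converters further below.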
tools/dataset_converters/cityscapes.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+
5
+ from cityscapesscripts.preparation.json2labelImg import json2labelImg
6
+ from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
7
+ track_progress)
8
+
9
+
10
+ def convert_json_to_label(json_file):
11
+ label_file = json_file.replace('_polygons.json', '_labelTrainIds.png')
12
+ json2labelImg(json_file, label_file, 'trainIds')
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(
17
+ description='Convert Cityscapes annotations to TrainIds')
18
+ parser.add_argument('cityscapes_path', help='cityscapes data path')
19
+ parser.add_argument('--gt-dir', default='gtFine', type=str)
20
+ parser.add_argument('-o', '--out-dir', help='output path')
21
+ parser.add_argument(
22
+ '--nproc', default=1, type=int, help='number of process')
23
+ args = parser.parse_args()
24
+ return args
25
+
26
+
27
+ def main():
28
+ args = parse_args()
29
+ cityscapes_path = args.cityscapes_path
30
+ out_dir = args.out_dir if args.out_dir else cityscapes_path
31
+ mkdir_or_exist(out_dir)
32
+
33
+ gt_dir = osp.join(cityscapes_path, args.gt_dir)
34
+
35
+ poly_files = []
36
+ for poly in scandir(gt_dir, '_polygons.json', recursive=True):
37
+ poly_file = osp.join(gt_dir, poly)
38
+ poly_files.append(poly_file)
39
+ if args.nproc > 1:
40
+ track_parallel_progress(convert_json_to_label, poly_files, args.nproc)
41
+ else:
42
+ track_progress(convert_json_to_label, poly_files)
43
+
44
+ split_names = ['train', 'val', 'test']
45
+
46
+ for split in split_names:
47
+ filenames = []
48
+ for poly in scandir(
49
+ osp.join(gt_dir, split), '_polygons.json', recursive=True):
50
+ filenames.append(poly.replace('_gtFine_polygons.json', ''))
51
+ with open(osp.join(out_dir, f'{split}.txt'), 'w') as f:
52
+ f.writelines(f + '\n' for f in filenames)
53
+
54
+
55
+ if __name__ == '__main__':
56
+ main()
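The split files written at the end list image stems obtained by stripping the annotation suffix, while convert_json_to_label derives the label filename from the polygon filename. A quick illustration of both string transforms on a typical Cityscapes annotation name (the city and frame id here are only an example):

poly = 'aachen/aachen_000000_000019_gtFine_polygons.json'
label_file = poly.replace('_polygons.json', '_labelTrainIds.png')
stem = poly.replace('_gtFine_polygons.json', '')
print(label_file)  # aachen/aachen_000000_000019_gtFine_labelTrainIds.png
print(stem)        # aachen/aachen_000000_000019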
tools/dataset_converters/coco_stuff10k.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ import shutil
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
9
+ track_progress)
10
+ from PIL import Image
11
+ from scipy.io import loadmat
12
+
13
+ COCO_LEN = 10000
14
+
15
+ clsID_to_trID = {
16
+ 0: 0,
17
+ 1: 1,
18
+ 2: 2,
19
+ 3: 3,
20
+ 4: 4,
21
+ 5: 5,
22
+ 6: 6,
23
+ 7: 7,
24
+ 8: 8,
25
+ 9: 9,
26
+ 10: 10,
27
+ 11: 11,
28
+ 13: 12,
29
+ 14: 13,
30
+ 15: 14,
31
+ 16: 15,
32
+ 17: 16,
33
+ 18: 17,
34
+ 19: 18,
35
+ 20: 19,
36
+ 21: 20,
37
+ 22: 21,
38
+ 23: 22,
39
+ 24: 23,
40
+ 25: 24,
41
+ 27: 25,
42
+ 28: 26,
43
+ 31: 27,
44
+ 32: 28,
45
+ 33: 29,
46
+ 34: 30,
47
+ 35: 31,
48
+ 36: 32,
49
+ 37: 33,
50
+ 38: 34,
51
+ 39: 35,
52
+ 40: 36,
53
+ 41: 37,
54
+ 42: 38,
55
+ 43: 39,
56
+ 44: 40,
57
+ 46: 41,
58
+ 47: 42,
59
+ 48: 43,
60
+ 49: 44,
61
+ 50: 45,
62
+ 51: 46,
63
+ 52: 47,
64
+ 53: 48,
65
+ 54: 49,
66
+ 55: 50,
67
+ 56: 51,
68
+ 57: 52,
69
+ 58: 53,
70
+ 59: 54,
71
+ 60: 55,
72
+ 61: 56,
73
+ 62: 57,
74
+ 63: 58,
75
+ 64: 59,
76
+ 65: 60,
77
+ 67: 61,
78
+ 70: 62,
79
+ 72: 63,
80
+ 73: 64,
81
+ 74: 65,
82
+ 75: 66,
83
+ 76: 67,
84
+ 77: 68,
85
+ 78: 69,
86
+ 79: 70,
87
+ 80: 71,
88
+ 81: 72,
89
+ 82: 73,
90
+ 84: 74,
91
+ 85: 75,
92
+ 86: 76,
93
+ 87: 77,
94
+ 88: 78,
95
+ 89: 79,
96
+ 90: 80,
97
+ 92: 81,
98
+ 93: 82,
99
+ 94: 83,
100
+ 95: 84,
101
+ 96: 85,
102
+ 97: 86,
103
+ 98: 87,
104
+ 99: 88,
105
+ 100: 89,
106
+ 101: 90,
107
+ 102: 91,
108
+ 103: 92,
109
+ 104: 93,
110
+ 105: 94,
111
+ 106: 95,
112
+ 107: 96,
113
+ 108: 97,
114
+ 109: 98,
115
+ 110: 99,
116
+ 111: 100,
117
+ 112: 101,
118
+ 113: 102,
119
+ 114: 103,
120
+ 115: 104,
121
+ 116: 105,
122
+ 117: 106,
123
+ 118: 107,
124
+ 119: 108,
125
+ 120: 109,
126
+ 121: 110,
127
+ 122: 111,
128
+ 123: 112,
129
+ 124: 113,
130
+ 125: 114,
131
+ 126: 115,
132
+ 127: 116,
133
+ 128: 117,
134
+ 129: 118,
135
+ 130: 119,
136
+ 131: 120,
137
+ 132: 121,
138
+ 133: 122,
139
+ 134: 123,
140
+ 135: 124,
141
+ 136: 125,
142
+ 137: 126,
143
+ 138: 127,
144
+ 139: 128,
145
+ 140: 129,
146
+ 141: 130,
147
+ 142: 131,
148
+ 143: 132,
149
+ 144: 133,
150
+ 145: 134,
151
+ 146: 135,
152
+ 147: 136,
153
+ 148: 137,
154
+ 149: 138,
155
+ 150: 139,
156
+ 151: 140,
157
+ 152: 141,
158
+ 153: 142,
159
+ 154: 143,
160
+ 155: 144,
161
+ 156: 145,
162
+ 157: 146,
163
+ 158: 147,
164
+ 159: 148,
165
+ 160: 149,
166
+ 161: 150,
167
+ 162: 151,
168
+ 163: 152,
169
+ 164: 153,
170
+ 165: 154,
171
+ 166: 155,
172
+ 167: 156,
173
+ 168: 157,
174
+ 169: 158,
175
+ 170: 159,
176
+ 171: 160,
177
+ 172: 161,
178
+ 173: 162,
179
+ 174: 163,
180
+ 175: 164,
181
+ 176: 165,
182
+ 177: 166,
183
+ 178: 167,
184
+ 179: 168,
185
+ 180: 169,
186
+ 181: 170,
187
+ 182: 171
188
+ }
189
+
190
+
191
+ def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir,
192
+ out_mask_dir, is_train):
193
+ imgpath, maskpath = tuple_path
194
+ shutil.copyfile(
195
+ osp.join(in_img_dir, imgpath),
196
+ osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join(
197
+ out_img_dir, 'test2014', imgpath))
198
+ annotate = loadmat(osp.join(in_ann_dir, maskpath))
199
+ mask = annotate['S'].astype(np.uint8)
200
+ mask_copy = mask.copy()
201
+ for clsID, trID in clsID_to_trID.items():
202
+ mask_copy[mask == clsID] = trID
203
+ seg_filename = osp.join(out_mask_dir, 'train2014',
204
+ maskpath.split('.')[0] +
205
+ '_labelTrainIds.png') if is_train else osp.join(
206
+ out_mask_dir, 'test2014',
207
+ maskpath.split('.')[0] + '_labelTrainIds.png')
208
+ Image.fromarray(mask_copy).save(seg_filename, 'PNG')
209
+
210
+
211
+ def generate_coco_list(folder):
212
+ train_list = osp.join(folder, 'imageLists', 'train.txt')
213
+ test_list = osp.join(folder, 'imageLists', 'test.txt')
214
+ train_paths = []
215
+ test_paths = []
216
+
217
+ with open(train_list) as f:
218
+ for filename in f:
219
+ basename = filename.strip()
220
+ imgpath = basename + '.jpg'
221
+ maskpath = basename + '.mat'
222
+ train_paths.append((imgpath, maskpath))
223
+
224
+ with open(test_list) as f:
225
+ for filename in f:
226
+ basename = filename.strip()
227
+ imgpath = basename + '.jpg'
228
+ maskpath = basename + '.mat'
229
+ test_paths.append((imgpath, maskpath))
230
+
231
+ return train_paths, test_paths
232
+
233
+
234
+ def parse_args():
235
+ parser = argparse.ArgumentParser(
236
+ description=\
237
+ 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa
238
+ parser.add_argument('coco_path', help='coco stuff path')
239
+ parser.add_argument('-o', '--out_dir', help='output path')
240
+ parser.add_argument(
241
+ '--nproc', default=16, type=int, help='number of process')
242
+ args = parser.parse_args()
243
+ return args
244
+
245
+
246
+ def main():
247
+ args = parse_args()
248
+ coco_path = args.coco_path
249
+ nproc = args.nproc
250
+
251
+ out_dir = args.out_dir or coco_path
252
+ out_img_dir = osp.join(out_dir, 'images')
253
+ out_mask_dir = osp.join(out_dir, 'annotations')
254
+
255
+ mkdir_or_exist(osp.join(out_img_dir, 'train2014'))
256
+ mkdir_or_exist(osp.join(out_img_dir, 'test2014'))
257
+ mkdir_or_exist(osp.join(out_mask_dir, 'train2014'))
258
+ mkdir_or_exist(osp.join(out_mask_dir, 'test2014'))
259
+
260
+ train_list, test_list = generate_coco_list(coco_path)
261
+ assert (len(train_list) +
262
+ len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
263
+ len(train_list), len(test_list))
264
+
265
+ if args.nproc > 1:
266
+ track_parallel_progress(
267
+ partial(
268
+ convert_to_trainID,
269
+ in_img_dir=osp.join(coco_path, 'images'),
270
+ in_ann_dir=osp.join(coco_path, 'annotations'),
271
+ out_img_dir=out_img_dir,
272
+ out_mask_dir=out_mask_dir,
273
+ is_train=True),
274
+ train_list,
275
+ nproc=nproc)
276
+ track_parallel_progress(
277
+ partial(
278
+ convert_to_trainID,
279
+ in_img_dir=osp.join(coco_path, 'images'),
280
+ in_ann_dir=osp.join(coco_path, 'annotations'),
281
+ out_img_dir=out_img_dir,
282
+ out_mask_dir=out_mask_dir,
283
+ is_train=False),
284
+ test_list,
285
+ nproc=nproc)
286
+ else:
287
+ track_progress(
288
+ partial(
289
+ convert_to_trainID,
290
+ in_img_dir=osp.join(coco_path, 'images'),
291
+ in_ann_dir=osp.join(coco_path, 'annotations'),
292
+ out_img_dir=out_img_dir,
293
+ out_mask_dir=out_mask_dir,
294
+ is_train=True), train_list)
295
+ track_progress(
296
+ partial(
297
+ convert_to_trainID,
298
+ in_img_dir=osp.join(coco_path, 'images'),
299
+ in_ann_dir=osp.join(coco_path, 'annotations'),
300
+ out_img_dir=out_img_dir,
301
+ out_mask_dir=out_mask_dir,
302
+ is_train=False), test_list)
303
+
304
+ print('Done!')
305
+
306
+
307
+ if __name__ == '__main__':
308
+ main()
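convert_to_trainID remaps every original COCO-Stuff class id to a contiguous train id with one boolean-mask assignment per class. A small sketch with a toy mapping (the real table above has about 170 entries) that also shows an equivalent single-pass lookup-table formulation:

import numpy as np

cls_to_tr = {0: 0, 1: 1, 2: 2, 13: 12}               # toy subset of clsID_to_trID
mask = np.array([[0, 1], [2, 13]], dtype=np.uint8)

# Loop-based remap, as in convert_to_trainID:
remapped = mask.copy()
for cls_id, tr_id in cls_to_tr.items():
    remapped[mask == cls_id] = tr_id

# Equivalent lookup table: one fancy-indexing pass over the whole mask.
lut = np.arange(256, dtype=np.uint8)
for cls_id, tr_id in cls_to_tr.items():
    lut[cls_id] = tr_id
assert np.array_equal(lut[mask], remapped)

The 164k converter below uses the same per-class loop, only with 255 kept as the ignore index.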
tools/dataset_converters/coco_stuff164k.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ import shutil
5
+ from functools import partial
6
+ from glob import glob
7
+
8
+ import numpy as np
9
+ from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
10
+ track_progress)
11
+ from PIL import Image
12
+
13
+ COCO_LEN = 123287
14
+
15
+ clsID_to_trID = {
16
+ 0: 0,
17
+ 1: 1,
18
+ 2: 2,
19
+ 3: 3,
20
+ 4: 4,
21
+ 5: 5,
22
+ 6: 6,
23
+ 7: 7,
24
+ 8: 8,
25
+ 9: 9,
26
+ 10: 10,
27
+ 12: 11,
28
+ 13: 12,
29
+ 14: 13,
30
+ 15: 14,
31
+ 16: 15,
32
+ 17: 16,
33
+ 18: 17,
34
+ 19: 18,
35
+ 20: 19,
36
+ 21: 20,
37
+ 22: 21,
38
+ 23: 22,
39
+ 24: 23,
40
+ 26: 24,
41
+ 27: 25,
42
+ 30: 26,
43
+ 31: 27,
44
+ 32: 28,
45
+ 33: 29,
46
+ 34: 30,
47
+ 35: 31,
48
+ 36: 32,
49
+ 37: 33,
50
+ 38: 34,
51
+ 39: 35,
52
+ 40: 36,
53
+ 41: 37,
54
+ 42: 38,
55
+ 43: 39,
56
+ 45: 40,
57
+ 46: 41,
58
+ 47: 42,
59
+ 48: 43,
60
+ 49: 44,
61
+ 50: 45,
62
+ 51: 46,
63
+ 52: 47,
64
+ 53: 48,
65
+ 54: 49,
66
+ 55: 50,
67
+ 56: 51,
68
+ 57: 52,
69
+ 58: 53,
70
+ 59: 54,
71
+ 60: 55,
72
+ 61: 56,
73
+ 62: 57,
74
+ 63: 58,
75
+ 64: 59,
76
+ 66: 60,
77
+ 69: 61,
78
+ 71: 62,
79
+ 72: 63,
80
+ 73: 64,
81
+ 74: 65,
82
+ 75: 66,
83
+ 76: 67,
84
+ 77: 68,
85
+ 78: 69,
86
+ 79: 70,
87
+ 80: 71,
88
+ 81: 72,
89
+ 83: 73,
90
+ 84: 74,
91
+ 85: 75,
92
+ 86: 76,
93
+ 87: 77,
94
+ 88: 78,
95
+ 89: 79,
96
+ 91: 80,
97
+ 92: 81,
98
+ 93: 82,
99
+ 94: 83,
100
+ 95: 84,
101
+ 96: 85,
102
+ 97: 86,
103
+ 98: 87,
104
+ 99: 88,
105
+ 100: 89,
106
+ 101: 90,
107
+ 102: 91,
108
+ 103: 92,
109
+ 104: 93,
110
+ 105: 94,
111
+ 106: 95,
112
+ 107: 96,
113
+ 108: 97,
114
+ 109: 98,
115
+ 110: 99,
116
+ 111: 100,
117
+ 112: 101,
118
+ 113: 102,
119
+ 114: 103,
120
+ 115: 104,
121
+ 116: 105,
122
+ 117: 106,
123
+ 118: 107,
124
+ 119: 108,
125
+ 120: 109,
126
+ 121: 110,
127
+ 122: 111,
128
+ 123: 112,
129
+ 124: 113,
130
+ 125: 114,
131
+ 126: 115,
132
+ 127: 116,
133
+ 128: 117,
134
+ 129: 118,
135
+ 130: 119,
136
+ 131: 120,
137
+ 132: 121,
138
+ 133: 122,
139
+ 134: 123,
140
+ 135: 124,
141
+ 136: 125,
142
+ 137: 126,
143
+ 138: 127,
144
+ 139: 128,
145
+ 140: 129,
146
+ 141: 130,
147
+ 142: 131,
148
+ 143: 132,
149
+ 144: 133,
150
+ 145: 134,
151
+ 146: 135,
152
+ 147: 136,
153
+ 148: 137,
154
+ 149: 138,
155
+ 150: 139,
156
+ 151: 140,
157
+ 152: 141,
158
+ 153: 142,
159
+ 154: 143,
160
+ 155: 144,
161
+ 156: 145,
162
+ 157: 146,
163
+ 158: 147,
164
+ 159: 148,
165
+ 160: 149,
166
+ 161: 150,
167
+ 162: 151,
168
+ 163: 152,
169
+ 164: 153,
170
+ 165: 154,
171
+ 166: 155,
172
+ 167: 156,
173
+ 168: 157,
174
+ 169: 158,
175
+ 170: 159,
176
+ 171: 160,
177
+ 172: 161,
178
+ 173: 162,
179
+ 174: 163,
180
+ 175: 164,
181
+ 176: 165,
182
+ 177: 166,
183
+ 178: 167,
184
+ 179: 168,
185
+ 180: 169,
186
+ 181: 170,
187
+ 255: 255
188
+ }
189
+
190
+
191
+ def convert_to_trainID(maskpath, out_mask_dir, is_train):
192
+ mask = np.array(Image.open(maskpath))
193
+ mask_copy = mask.copy()
194
+ for clsID, trID in clsID_to_trID.items():
195
+ mask_copy[mask == clsID] = trID
196
+ seg_filename = osp.join(
197
+ out_mask_dir, 'train2017',
198
+ osp.basename(maskpath).split('.')[0] +
199
+ '_labelTrainIds.png') if is_train else osp.join(
200
+ out_mask_dir, 'val2017',
201
+ osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png')
202
+ Image.fromarray(mask_copy).save(seg_filename, 'PNG')
203
+
204
+
205
+ def parse_args():
206
+ parser = argparse.ArgumentParser(
207
+ description=\
208
+ 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa
209
+ parser.add_argument('coco_path', help='coco stuff path')
210
+ parser.add_argument('-o', '--out_dir', help='output path')
211
+ parser.add_argument(
212
+ '--nproc', default=16, type=int, help='number of process')
213
+ args = parser.parse_args()
214
+ return args
215
+
216
+
217
+ def main():
218
+ args = parse_args()
219
+ coco_path = args.coco_path
220
+ nproc = args.nproc
221
+
222
+ out_dir = args.out_dir or coco_path
223
+ out_img_dir = osp.join(out_dir, 'images')
224
+ out_mask_dir = osp.join(out_dir, 'annotations')
225
+
226
+ mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
227
+ mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
228
+
229
+ if out_dir != coco_path:
230
+ shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
231
+
232
+ train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
233
+ train_list = [file for file in train_list if '_labelTrainIds' not in file]
234
+ test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
235
+ test_list = [file for file in test_list if '_labelTrainIds' not in file]
236
+ assert (len(train_list) +
237
+ len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
238
+ len(train_list), len(test_list))
239
+
240
+ if args.nproc > 1:
241
+ track_parallel_progress(
242
+ partial(
243
+ convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
244
+ train_list,
245
+ nproc=nproc)
246
+ track_parallel_progress(
247
+ partial(
248
+ convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
249
+ test_list,
250
+ nproc=nproc)
251
+ else:
252
+ track_progress(
253
+ partial(
254
+ convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
255
+ train_list)
256
+ track_progress(
257
+ partial(
258
+ convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
259
+ test_list)
260
+
261
+ print('Done!')
262
+
263
+
264
+ if __name__ == '__main__':
265
+ main()
tools/dataset_converters/hrf.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import tempfile
6
+ import zipfile
7
+
8
+ import mmcv
9
+ from mmengine.utils import mkdir_or_exist
10
+
11
+ HRF_LEN = 15
12
+ TRAINING_LEN = 5
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(
17
+ description='Convert HRF dataset to mmsegmentation format')
18
+ parser.add_argument('healthy_path', help='the path of healthy.zip')
19
+ parser.add_argument(
20
+ 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip')
21
+ parser.add_argument('glaucoma_path', help='the path of glaucoma.zip')
22
+ parser.add_argument(
23
+ 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip')
24
+ parser.add_argument(
25
+ 'diabetic_retinopathy_path',
26
+ help='the path of diabetic_retinopathy.zip')
27
+ parser.add_argument(
28
+ 'diabetic_retinopathy_manualsegm_path',
29
+ help='the path of diabetic_retinopathy_manualsegm.zip')
30
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
31
+ parser.add_argument('-o', '--out_dir', help='output path')
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def main():
37
+ args = parse_args()
38
+ images_path = [
39
+ args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path
40
+ ]
41
+ annotations_path = [
42
+ args.healthy_manualsegm_path, args.glaucoma_manualsegm_path,
43
+ args.diabetic_retinopathy_manualsegm_path
44
+ ]
45
+ if args.out_dir is None:
46
+ out_dir = osp.join('data', 'HRF')
47
+ else:
48
+ out_dir = args.out_dir
49
+
50
+ print('Making directories...')
51
+ mkdir_or_exist(out_dir)
52
+ mkdir_or_exist(osp.join(out_dir, 'images'))
53
+ mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
54
+ mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
55
+ mkdir_or_exist(osp.join(out_dir, 'annotations'))
56
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
57
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
58
+
59
+ print('Generating images...')
60
+ for now_path in images_path:
61
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
62
+ zip_file = zipfile.ZipFile(now_path)
63
+ zip_file.extractall(tmp_dir)
64
+
65
+ assert len(os.listdir(tmp_dir)) == HRF_LEN, \
66
+ f'len(os.listdir(tmp_dir)) != {HRF_LEN}'
67
+
68
+ for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
69
+ img = mmcv.imread(osp.join(tmp_dir, filename))
70
+ mmcv.imwrite(
71
+ img,
72
+ osp.join(out_dir, 'images', 'training',
73
+ osp.splitext(filename)[0] + '.png'))
74
+ for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
75
+ img = mmcv.imread(osp.join(tmp_dir, filename))
76
+ mmcv.imwrite(
77
+ img,
78
+ osp.join(out_dir, 'images', 'validation',
79
+ osp.splitext(filename)[0] + '.png'))
80
+
81
+ print('Generating annotations...')
82
+ for now_path in annotations_path:
83
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
84
+ zip_file = zipfile.ZipFile(now_path)
85
+ zip_file.extractall(tmp_dir)
86
+
87
+ assert len(os.listdir(tmp_dir)) == HRF_LEN, \
88
+ f'len(os.listdir(tmp_dir)) != {HRF_LEN}'
89
+
90
+ for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
91
+ img = mmcv.imread(osp.join(tmp_dir, filename))
92
+ # The annotation img should be divided by 128, because some of
93
+ # the annotation imgs are not standard. We should set a
94
+ # threshold to convert the nonstandard annotation imgs. The
95
+ # value divided by 128 is equivalent to '1 if value >= 128
96
+ # else 0'
97
+ mmcv.imwrite(
98
+ img[:, :, 0] // 128,
99
+ osp.join(out_dir, 'annotations', 'training',
100
+ osp.splitext(filename)[0] + '.png'))
101
+ for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
102
+ img = mmcv.imread(osp.join(tmp_dir, filename))
103
+ mmcv.imwrite(
104
+ img[:, :, 0] // 128,
105
+ osp.join(out_dir, 'annotations', 'validation',
106
+ osp.splitext(filename)[0] + '.png'))
107
+
108
+ print('Done!')
109
+
110
+
111
+ if __name__ == '__main__':
112
+ main()
tools/dataset_converters/isaid.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import glob
4
+ import os
5
+ import os.path as osp
6
+ import shutil
7
+ import tempfile
8
+ import zipfile
9
+
10
+ import mmcv
11
+ import numpy as np
12
+ from mmengine.utils import ProgressBar, mkdir_or_exist
13
+ from PIL import Image
14
+
15
+ iSAID_palette = \
16
+ {
17
+ 0: (0, 0, 0),
18
+ 1: (0, 0, 63),
19
+ 2: (0, 63, 63),
20
+ 3: (0, 63, 0),
21
+ 4: (0, 63, 127),
22
+ 5: (0, 63, 191),
23
+ 6: (0, 63, 255),
24
+ 7: (0, 127, 63),
25
+ 8: (0, 127, 127),
26
+ 9: (0, 0, 127),
27
+ 10: (0, 0, 191),
28
+ 11: (0, 0, 255),
29
+ 12: (0, 191, 127),
30
+ 13: (0, 127, 191),
31
+ 14: (0, 127, 255),
32
+ 15: (0, 100, 155)
33
+ }
34
+
35
+ iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
36
+
37
+
38
+ def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
39
+ """RGB-color encoding to grayscale labels."""
40
+ arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
41
+
42
+ for c, i in palette.items():
43
+ m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
44
+ arr_2d[m] = i
45
+
46
+ return arr_2d
47
+
48
+
49
+ def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
50
+ img = np.asarray(Image.open(src_path).convert('RGB'))
51
+
52
+ img_H, img_W, _ = img.shape
53
+
54
+ if img_H < patch_H and img_W > patch_W:
55
+
56
+ img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
57
+
58
+ img_H, img_W, _ = img.shape
59
+
60
+ elif img_H > patch_H and img_W < patch_W:
61
+
62
+ img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
63
+
64
+ img_H, img_W, _ = img.shape
65
+
66
+ elif img_H < patch_H and img_W < patch_W:
67
+
68
+ img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
69
+
70
+ img_H, img_W, _ = img.shape
71
+
72
+ for x in range(0, img_W, patch_W - overlap):
73
+ for y in range(0, img_H, patch_H - overlap):
74
+ x_str = x
75
+ x_end = x + patch_W
76
+ if x_end > img_W:
77
+ diff_x = x_end - img_W
78
+ x_str -= diff_x
79
+ x_end = img_W
80
+ y_str = y
81
+ y_end = y + patch_H
82
+ if y_end > img_H:
83
+ diff_y = y_end - img_H
84
+ y_str -= diff_y
85
+ y_end = img_H
86
+
87
+ img_patch = img[y_str:y_end, x_str:x_end, :]
88
+ img_patch = Image.fromarray(img_patch.astype(np.uint8))
89
+ image = osp.basename(src_path).split('.')[0] + '_' + str(
90
+ y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
91
+ x_end) + '.png'
92
+ # print(image)
93
+ save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
94
+ img_patch.save(save_path_image, format='BMP')
95
+
96
+
97
+ def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
98
+ label = mmcv.imread(src_path, channel_order='rgb')
99
+ label = iSAID_convert_from_color(label)
100
+ img_H, img_W = label.shape
101
+
102
+ if img_H < patch_H and img_W > patch_W:
103
+
104
+ label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
105
+
106
+ img_H = patch_H
107
+
108
+ elif img_H > patch_H and img_W < patch_W:
109
+
110
+ label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
111
+
112
+ img_W = patch_W
113
+
114
+ elif img_H < patch_H and img_W < patch_W:
115
+
116
+ label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
117
+
118
+ img_H = patch_H
119
+ img_W = patch_W
120
+
121
+ for x in range(0, img_W, patch_W - overlap):
122
+ for y in range(0, img_H, patch_H - overlap):
123
+ x_str = x
124
+ x_end = x + patch_W
125
+ if x_end > img_W:
126
+ diff_x = x_end - img_W
127
+ x_str -= diff_x
128
+ x_end = img_W
129
+ y_str = y
130
+ y_end = y + patch_H
131
+ if y_end > img_H:
132
+ diff_y = y_end - img_H
133
+ y_str -= diff_y
134
+ y_end = img_H
135
+
136
+ lab_patch = label[y_str:y_end, x_str:x_end]
137
+ lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')
138
+
139
+ image = osp.basename(src_path).split('.')[0].split(
140
+ '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
141
+ x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
142
+ lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
143
+
144
+
145
+ def parse_args():
146
+ parser = argparse.ArgumentParser(
147
+ description='Convert iSAID dataset to mmsegmentation format')
148
+ parser.add_argument('dataset_path', help='iSAID folder path')
149
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
150
+ parser.add_argument('-o', '--out_dir', help='output path')
151
+
152
+ parser.add_argument(
153
+ '--patch_width',
154
+ default=896,
155
+ type=int,
156
+ help='Width of the cropped image patch')
157
+ parser.add_argument(
158
+ '--patch_height',
159
+ default=896,
160
+ type=int,
161
+ help='Height of the cropped image patch')
162
+ parser.add_argument(
163
+ '--overlap_area', default=384, type=int, help='Overlap area')
164
+ args = parser.parse_args()
165
+ return args
166
+
167
+
168
+ def main():
169
+ args = parse_args()
170
+ dataset_path = args.dataset_path
171
+ # image patch width and height
172
+ patch_H, patch_W = args.patch_width, args.patch_height
173
+
174
+ overlap = args.overlap_area # overlap area
175
+
176
+ if args.out_dir is None:
177
+ out_dir = osp.join('data', 'iSAID')
178
+ else:
179
+ out_dir = args.out_dir
180
+
181
+ print('Making directories...')
182
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
183
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
184
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
185
+
186
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
187
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
188
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
189
+
190
+ assert os.path.exists(os.path.join(dataset_path, 'train')), \
191
+ f'train is not in {dataset_path}'
192
+ assert os.path.exists(os.path.join(dataset_path, 'val')), \
193
+ f'val is not in {dataset_path}'
194
+ assert os.path.exists(os.path.join(dataset_path, 'test')), \
195
+ f'test is not in {dataset_path}'
196
+
197
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
198
+ for dataset_mode in ['train', 'val', 'test']:
199
+
200
+ # for dataset_mode in [ 'test']:
201
+ print(f'Extracting {dataset_mode}ing.zip...')
202
+ img_zipp_list = glob.glob(
203
+ os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
204
+ print('Find the data', img_zipp_list)
205
+ for img_zipp in img_zipp_list:
206
+ zip_file = zipfile.ZipFile(img_zipp)
207
+ zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
208
+ src_path_list = glob.glob(
209
+ os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
210
+
211
+ src_prog_bar = ProgressBar(len(src_path_list))
212
+ for i, img_path in enumerate(src_path_list):
213
+ if dataset_mode != 'test':
214
+ slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
215
+ patch_W, overlap)
216
+
217
+ else:
218
+ shutil.move(img_path,
219
+ os.path.join(out_dir, 'img_dir', dataset_mode))
220
+ src_prog_bar.update()
221
+
222
+ if dataset_mode != 'test':
223
+ label_zipp_list = glob.glob(
224
+ os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
225
+ '*.zip'))
226
+ for label_zipp in label_zipp_list:
227
+ zip_file = zipfile.ZipFile(label_zipp)
228
+ zip_file.extractall(
229
+ os.path.join(tmp_dir, dataset_mode, 'lab'))
230
+
231
+ lab_path_list = glob.glob(
232
+ os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
233
+ '*.png'))
234
+ lab_prog_bar = ProgressBar(len(lab_path_list))
235
+ for i, lab_path in enumerate(lab_path_list):
236
+ slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
237
+ patch_W, overlap)
238
+ lab_prog_bar.update()
239
+
240
+ print('Removing the temporary files...')
241
+
242
+ print('Done!')
243
+
244
+
245
+ if __name__ == '__main__':
246
+ main()
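iSAID_convert_from_color recovers grayscale class indices from the RGB palette by matching every pixel against each palette colour in turn. A minimal check on a hand-made 2x2 image, using only a subset of the palette:

import numpy as np

palette = {0: (0, 0, 0), 1: (0, 0, 63), 2: (0, 63, 63)}   # subset of iSAID_palette
invert = {v: k for k, v in palette.items()}

rgb = np.array([[(0, 0, 0), (0, 0, 63)],
                [(0, 63, 63), (0, 0, 0)]], dtype=np.uint8)

label = np.zeros(rgb.shape[:2], dtype=np.uint8)
for color, idx in invert.items():
    m = np.all(rgb == np.array(color).reshape(1, 1, 3), axis=2)
    label[m] = idx
print(label)  # [[0 1]
              #  [2 0]]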
tools/dataset_converters/levircd.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import glob
4
+ import math
5
+ import os
6
+ import os.path as osp
7
+
8
+ import mmcv
9
+ import numpy as np
10
+ from mmengine.utils import ProgressBar
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(
15
+ description='Convert levir-cd dataset to mmsegmentation format')
16
+ parser.add_argument('--dataset_path', help='LEVIR-CD folder path')
17
+ parser.add_argument('-o', '--out_dir', help='output path')
18
+ parser.add_argument(
19
+ '--clip_size',
20
+ type=int,
21
+ help='clipped size of image after preparation',
22
+ default=256)
23
+ parser.add_argument(
24
+ '--stride_size',
25
+ type=int,
26
+ help='stride of clipping original images',
27
+ default=256)
28
+ args = parser.parse_args()
29
+ return args
30
+
31
+
32
+ def main():
33
+ args = parse_args()
34
+ input_folder = args.dataset_path
35
+ png_files = glob.glob(
36
+ os.path.join(input_folder, '**/*.png'), recursive=True)
37
+ output_folder = args.out_dir
38
+ prog_bar = ProgressBar(len(png_files))
39
+ for png_file in png_files:
40
+ new_path = os.path.join(
41
+ output_folder,
42
+ os.path.relpath(os.path.dirname(png_file), input_folder))
43
+ os.makedirs(os.path.dirname(new_path), exist_ok=True)
44
+ label = False
45
+ if 'label' in png_file:
46
+ label = True
47
+ clip_big_image(png_file, new_path, args, label)
48
+ prog_bar.update()
49
+
50
+
51
+ def clip_big_image(image_path, clip_save_dir, args, to_label=False):
52
+ image = mmcv.imread(image_path)
53
+
54
+ h, w, c = image.shape
55
+ clip_size = args.clip_size
56
+ stride_size = args.stride_size
57
+
58
+ num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
59
+ (h - clip_size) /
60
+ stride_size) * stride_size + clip_size >= h else math.ceil(
61
+ (h - clip_size) / stride_size) + 1
62
+ num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
63
+ (w - clip_size) /
64
+ stride_size) * stride_size + clip_size >= w else math.ceil(
65
+ (w - clip_size) / stride_size) + 1
66
+
67
+ x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
68
+ xmin = x * clip_size
69
+ ymin = y * clip_size
70
+
71
+ xmin = xmin.ravel()
72
+ ymin = ymin.ravel()
73
+ xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
74
+ np.zeros_like(xmin))
75
+ ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
76
+ np.zeros_like(ymin))
77
+ boxes = np.stack([
78
+ xmin + xmin_offset, ymin + ymin_offset,
79
+ np.minimum(xmin + clip_size, w),
80
+ np.minimum(ymin + clip_size, h)
81
+ ],
82
+ axis=1)
83
+
84
+ if to_label:
85
+ image[image == 255] = 1
86
+ image = image[:, :, 0]
87
+ for box in boxes:
88
+ start_x, start_y, end_x, end_y = box
89
+ clipped_image = image[start_y:end_y, start_x:end_x] \
90
+ if to_label else image[start_y:end_y, start_x:end_x, :]
91
+ idx = osp.basename(image_path).split('.')[0]
92
+ mmcv.imwrite(
93
+ clipped_image.astype(np.uint8),
94
+ osp.join(clip_save_dir,
95
+ f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
96
+
97
+
98
+ if __name__ == '__main__':
99
+ main()
tools/dataset_converters/loveda.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import shutil
6
+ import tempfile
7
+ import zipfile
8
+
9
+ from mmengine.utils import mkdir_or_exist
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser(
14
+ description='Convert LoveDA dataset to mmsegmentation format')
15
+ parser.add_argument('dataset_path', help='LoveDA folder path')
16
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
17
+ parser.add_argument('-o', '--out_dir', help='output path')
18
+ args = parser.parse_args()
19
+ return args
20
+
21
+
22
+ def main():
23
+ args = parse_args()
24
+ dataset_path = args.dataset_path
25
+ if args.out_dir is None:
26
+ out_dir = osp.join('data', 'loveDA')
27
+ else:
28
+ out_dir = args.out_dir
29
+
30
+ print('Making directories...')
31
+ mkdir_or_exist(out_dir)
32
+ mkdir_or_exist(osp.join(out_dir, 'img_dir'))
33
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
34
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
35
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
36
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir'))
37
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
38
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
39
+
40
+ assert 'Train.zip' in os.listdir(dataset_path), \
41
+ f'Train.zip is not in {dataset_path}'
42
+ assert 'Val.zip' in os.listdir(dataset_path), \
43
+ f'Val.zip is not in {dataset_path}'
44
+ assert 'Test.zip' in os.listdir(dataset_path), \
45
+ f'Test.zip is not in {dataset_path}'
46
+
47
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
48
+ for dataset in ['Train', 'Val', 'Test']:
49
+ zip_file = zipfile.ZipFile(
50
+ os.path.join(dataset_path, dataset + '.zip'))
51
+ zip_file.extractall(tmp_dir)
52
+ data_type = dataset.lower()
53
+ for location in ['Rural', 'Urban']:
54
+ for image_type in ['images_png', 'masks_png']:
55
+ if image_type == 'images_png':
56
+ dst = osp.join(out_dir, 'img_dir', data_type)
57
+ else:
58
+ dst = osp.join(out_dir, 'ann_dir', data_type)
59
+ if dataset == 'Test' and image_type == 'masks_png':
60
+ continue
61
+ else:
62
+ src_dir = osp.join(tmp_dir, dataset, location,
63
+ image_type)
64
+ src_lst = os.listdir(src_dir)
65
+ for file in src_lst:
66
+ shutil.move(osp.join(src_dir, file), dst)
67
+ print('Removing the temporary files...')
68
+
69
+ print('Done!')
70
+
71
+
72
+ if __name__ == '__main__':
73
+ main()
tools/dataset_converters/nyu.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ import shutil
5
+ import tempfile
6
+ import zipfile
7
+
8
+ from mmengine.utils import mkdir_or_exist
9
+
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser(
13
+ description='Convert NYU Depth dataset to mmsegmentation format')
14
+ parser.add_argument('raw_data', help='the path of raw data')
15
+ parser.add_argument(
16
+ '-o', '--out_dir', help='output path', default='./data/nyu')
17
+ args = parser.parse_args()
18
+ return args
19
+
20
+
21
+ def reorganize(raw_data_dir: str, out_dir: str):
22
+ """Reorganize NYU Depth dataset files into the required directory
23
+ structure.
24
+
25
+ Args:
26
+ raw_data_dir (str): Path to the raw data directory.
27
+ out_dir (str): Output directory for the organized dataset.
28
+ """
29
+
30
+ def move_data(data_list, dst_prefix, fname_func):
31
+ """Move data files from source to destination directory.
32
+
33
+ Args:
34
+ data_list (list): List of data file paths.
35
+ dst_prefix (str): Prefix to be added to destination paths.
36
+ fname_func (callable): Function to process file names
37
+ """
38
+ for data_item in data_list:
39
+ data_item = data_item.strip().strip('/')
40
+ new_item = fname_func(data_item)
41
+ shutil.move(
42
+ osp.join(raw_data_dir, data_item),
43
+ osp.join(out_dir, dst_prefix, new_item))
44
+
45
+ def process_phase(phase):
46
+ """Process a dataset phase (e.g., 'train' or 'test')."""
47
+ with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f:
48
+ data = filter(lambda x: len(x.strip()) > 0, f.readlines())
49
+ data = map(lambda x: x.split()[:2], data)
50
+ images, annos = zip(*data)
51
+
52
+ move_data(images, f'images/{phase}',
53
+ lambda x: x.replace('/rgb', ''))
54
+ move_data(annos, f'annotations/{phase}',
55
+ lambda x: x.replace('/sync_depth', ''))
56
+
57
+ process_phase('train')
58
+ process_phase('test')
59
+
60
+
61
+ def main():
62
+ args = parse_args()
63
+
64
+ print('Making directories...')
65
+ mkdir_or_exist(args.out_dir)
66
+ for subdir in [
67
+ 'images/train', 'images/test', 'annotations/train',
68
+ 'annotations/test'
69
+ ]:
70
+ mkdir_or_exist(osp.join(args.out_dir, subdir))
71
+
72
+ print('Generating images and annotations...')
73
+
74
+ if args.raw_data.endswith('.zip'):
75
+ with tempfile.TemporaryDirectory() as tmp_dir:
76
+ zip_file = zipfile.ZipFile(args.raw_data)
77
+ zip_file.extractall(tmp_dir)
78
+ reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir)
79
+ else:
80
+ assert osp.isdir(
81
+ args.raw_data
82
+ ), 'the raw_data argument should be either a zip file or a directory.'
83
+ reorganize(args.raw_data, args.out_dir)
84
+
85
+ print('Done!')
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()
tools/dataset_converters/pascal_context.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ from detail import Detail
8
+ from mmengine.utils import mkdir_or_exist, track_progress
9
+ from PIL import Image
10
+
11
+ _mapping = np.sort(
12
+ np.array([
13
+ 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
14
+ 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
15
+ 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
16
+ 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
17
+ ]))
18
+ _key = np.array(range(len(_mapping))).astype('uint8')
19
+
20
+
21
+ def generate_labels(img_id, detail, out_dir):
22
+
23
+ def _class_to_index(mask, _mapping, _key):
24
+ # assert the values
25
+ values = np.unique(mask)
26
+ for i in range(len(values)):
27
+ assert (values[i] in _mapping)
28
+ index = np.digitize(mask.ravel(), _mapping, right=True)
29
+ return _key[index].reshape(mask.shape)
30
+
31
+ mask = Image.fromarray(
32
+ _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
33
+ filename = img_id['file_name']
34
+ mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
35
+ return osp.splitext(osp.basename(filename))[0]
36
+
37
+
38
+ def parse_args():
39
+ parser = argparse.ArgumentParser(
40
+ description='Convert PASCAL Context annotations to mmsegmentation format')
41
+ parser.add_argument('devkit_path', help='pascal voc devkit path')
42
+ parser.add_argument('json_path', help='annotation json filepath')
43
+ parser.add_argument('-o', '--out_dir', help='output path')
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def main():
49
+ args = parse_args()
50
+ devkit_path = args.devkit_path
51
+ if args.out_dir is None:
52
+ out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
53
+ else:
54
+ out_dir = args.out_dir
55
+ json_path = args.json_path
56
+ mkdir_or_exist(out_dir)
57
+ img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')
58
+
59
+ train_detail = Detail(json_path, img_dir, 'train')
60
+ train_ids = train_detail.getImgs()
61
+
62
+ val_detail = Detail(json_path, img_dir, 'val')
63
+ val_ids = val_detail.getImgs()
64
+
65
+ mkdir_or_exist(
66
+ osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))
67
+
68
+ train_list = track_progress(
69
+ partial(generate_labels, detail=train_detail, out_dir=out_dir),
70
+ train_ids)
71
+ with open(
72
+ osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
73
+ 'train.txt'), 'w') as f:
74
+ f.writelines(line + '\n' for line in sorted(train_list))
75
+
76
+ val_list = track_progress(
77
+ partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
78
+ with open(
79
+ osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
80
+ 'val.txt'), 'w') as f:
81
+ f.writelines(line + '\n' for line in sorted(val_list))
82
+
83
+ print('Done!')
84
+
85
+
86
+ if __name__ == '__main__':
87
+ main()
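_class_to_index works because _mapping is sorted: np.digitize(..., right=True) returns, for every pixel value, the position of that value inside the sorted mapping, and _key converts that position into a contiguous train id. A toy verification with a five-entry subset of the mapping:

import numpy as np

mapping = np.sort(np.array([0, 2, 9, 18, 259]))   # toy subset of _mapping
key = np.arange(len(mapping)).astype('uint8')     # contiguous ids 0..4

mask = np.array([[0, 2], [259, 18]])
index = np.digitize(mask.ravel(), mapping, right=True)
print(key[index].reshape(mask.shape))
# [[0 1]
#  [4 3]]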
tools/dataset_converters/potsdam.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import glob
4
+ import math
5
+ import os
6
+ import os.path as osp
7
+ import tempfile
8
+ import zipfile
9
+
10
+ import mmcv
11
+ import numpy as np
12
+ from mmengine.utils import ProgressBar, mkdir_or_exist
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(
17
+ description='Convert potsdam dataset to mmsegmentation format')
18
+ parser.add_argument('dataset_path', help='potsdam folder path')
19
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
20
+ parser.add_argument('-o', '--out_dir', help='output path')
21
+ parser.add_argument(
22
+ '--clip_size',
23
+ type=int,
24
+ help='clipped size of image after preparation',
25
+ default=512)
26
+ parser.add_argument(
27
+ '--stride_size',
28
+ type=int,
29
+ help='stride of clipping original images',
30
+ default=256)
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ def clip_big_image(image_path, clip_save_dir, args, to_label=False):
36
+ # Original image of Potsdam dataset is very large, thus pre-processing
37
+ # of them is adopted. Given fixed clip size and stride size to generate
38
+ # clipped image, the intersection of width and height is determined.
39
+ # For example, given one 5120 x 5120 original image, the clip size is
40
+ # 512 and stride size is 256, thus it would generate 20x20 = 400 images
41
+ # whose size are all 512x512.
42
+ image = mmcv.imread(image_path)
43
+
44
+ h, w, c = image.shape
45
+ clip_size = args.clip_size
46
+ stride_size = args.stride_size
47
+
48
+ num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
49
+ (h - clip_size) /
50
+ stride_size) * stride_size + clip_size >= h else math.ceil(
51
+ (h - clip_size) / stride_size) + 1
52
+ num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
53
+ (w - clip_size) /
54
+ stride_size) * stride_size + clip_size >= w else math.ceil(
55
+ (w - clip_size) / stride_size) + 1
56
+
57
+ x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
58
+ xmin = x * clip_size
59
+ ymin = y * clip_size
60
+
61
+ xmin = xmin.ravel()
62
+ ymin = ymin.ravel()
63
+ xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
64
+ np.zeros_like(xmin))
65
+ ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
66
+ np.zeros_like(ymin))
67
+ boxes = np.stack([
68
+ xmin + xmin_offset, ymin + ymin_offset,
69
+ np.minimum(xmin + clip_size, w),
70
+ np.minimum(ymin + clip_size, h)
71
+ ],
72
+ axis=1)
73
+
74
+ if to_label:
75
+ color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
76
+ [255, 255, 0], [0, 255, 0], [0, 255, 255],
77
+ [0, 0, 255]])
78
+ flatten_v = np.matmul(
79
+ image.reshape(-1, c),
80
+ np.array([2, 3, 4]).reshape(3, 1))
81
+ out = np.zeros_like(flatten_v)
82
+ for idx, class_color in enumerate(color_map):
83
+ value_idx = np.matmul(class_color,
84
+ np.array([2, 3, 4]).reshape(3, 1))
85
+ out[flatten_v == value_idx] = idx
86
+ image = out.reshape(h, w)
87
+
88
+ for box in boxes:
89
+ start_x, start_y, end_x, end_y = box
90
+ clipped_image = image[start_y:end_y,
91
+ start_x:end_x] if to_label else image[
92
+ start_y:end_y, start_x:end_x, :]
93
+ idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
94
+ mmcv.imwrite(
95
+ clipped_image.astype(np.uint8),
96
+ osp.join(
97
+ clip_save_dir,
98
+ f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
99
+
100
+
101
+ def main():
102
+ args = parse_args()
103
+ splits = {
104
+ 'train': [
105
+ '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
106
+ '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
107
+ '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
108
+ ],
109
+ 'val': [
110
+ '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
111
+ '4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
112
+ ]
113
+ }
114
+
115
+ dataset_path = args.dataset_path
116
+ if args.out_dir is None:
117
+ out_dir = osp.join('data', 'potsdam')
118
+ else:
119
+ out_dir = args.out_dir
120
+
121
+ print('Making directories...')
122
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
123
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
124
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
125
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
126
+
127
+ zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
128
+ print('Find the data', zipp_list)
129
+
130
+ for zipp in zipp_list:
131
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
132
+ zip_file = zipfile.ZipFile(zipp)
133
+ zip_file.extractall(tmp_dir)
134
+ src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
135
+ if not len(src_path_list):
136
+ sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
137
+ src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))
138
+
139
+ prog_bar = ProgressBar(len(src_path_list))
140
+ for i, src_path in enumerate(src_path_list):
141
+ idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
142
+ data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
143
+ 'train'] else 'val'
144
+ if 'label' in src_path:
145
+ dst_dir = osp.join(out_dir, 'ann_dir', data_type)
146
+ clip_big_image(src_path, dst_dir, args, to_label=True)
147
+ else:
148
+ dst_dir = osp.join(out_dir, 'img_dir', data_type)
149
+ clip_big_image(src_path, dst_dir, args, to_label=False)
150
+ prog_bar.update()
151
+
152
+ print('Removing the temporary files...')
153
+
154
+ print('Done!')
155
+
156
+
157
+ if __name__ == '__main__':
158
+ main()
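The box construction above keeps every patch fully inside the image: whenever a window would run past the right or bottom border, its start is shifted back so the final patch still has the full clip size. A small sketch of that clamping along one axis, with made-up sizes (a 600-pixel extent and 512-pixel clips):

import numpy as np

w, clip_size = 600, 512
xmin = np.array([0, 512])                      # window starts along one axis
offset = np.where(xmin + clip_size > w,        # the second window overruns by 424
                  w - xmin - clip_size,
                  np.zeros_like(xmin))
starts = xmin + offset
ends = np.minimum(xmin + clip_size, w)
print([(int(a), int(b)) for a, b in zip(starts, ends)])
# [(0, 512), (88, 600)] -> both patches are 512 wide

The LEVIR-CD converter above applies the same clamping when clipping its image pairs.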
tools/dataset_converters/refuge.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import tempfile
6
+ import zipfile
7
+
8
+ import mmcv
9
+ import numpy as np
10
+ from mmengine.utils import mkdir_or_exist
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(
15
+ description='Convert REFUGE dataset to mmsegmentation format')
16
+ parser.add_argument('--raw_data_root', help='the root path of raw data')
17
+
18
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
19
+ parser.add_argument('-o', '--out_dir', help='output path')
20
+ args = parser.parse_args()
21
+ return args
22
+
23
+
24
+ def extract_img(root: str,
25
+ cur_dir: str,
26
+ out_dir: str,
27
+ mode: str = 'train',
28
+ file_type: str = 'img') -> None:
29
+ """_summary_
30
+
31
+ Args:
32
+ Args:
33
+ root (str): root where the extracted data is saved
34
+ cur_dir (cur_dir): dir where the zip_file exists
35
+ out_dir (str): root dir where the data is saved
36
+
37
+ mode (str, optional): Defaults to 'train'.
38
+ file_type (str, optional): Defaults to 'img',else to 'mask'.
39
+ """
40
+ zip_file = zipfile.ZipFile(cur_dir)
41
+ zip_file.extractall(root)
42
+ for cur_dir, dirs, files in os.walk(root):
43
+ # filter child dirs and directories with "Illustration" and "MACOSX"
44
+ if len(dirs) == 0 and \
45
+ cur_dir.split('\\')[-1].find('Illustration') == -1 and \
46
+ cur_dir.find('MACOSX') == -1:
47
+
48
+ file_names = [
49
+ file for file in files
50
+ if file.endswith('.jpg') or file.endswith('.bmp')
51
+ ]
52
+ for filename in sorted(file_names):
53
+ img = mmcv.imread(osp.join(cur_dir, filename))
54
+
55
+ if file_type == 'annotations':
56
+ img = img[:, :, 0]
57
+ img[np.where(img == 0)] = 1
58
+ img[np.where(img == 128)] = 2
59
+ img[np.where(img == 255)] = 0
60
+ mmcv.imwrite(
61
+ img,
62
+ osp.join(out_dir, file_type, mode,
63
+ osp.splitext(filename)[0] + '.png'))
64
+
65
+
66
+ def main():
67
+ args = parse_args()
68
+
69
+ raw_data_root = args.raw_data_root
70
+ if args.out_dir is None:
71
+ out_dir = osp.join('./data', 'REFUGE')
72
+
73
+ else:
74
+ out_dir = args.out_dir
75
+
76
+ print('Making directories...')
77
+ mkdir_or_exist(out_dir)
78
+ mkdir_or_exist(osp.join(out_dir, 'images'))
79
+ mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
80
+ mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
81
+ mkdir_or_exist(osp.join(out_dir, 'images', 'test'))
82
+ mkdir_or_exist(osp.join(out_dir, 'annotations'))
83
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
84
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
85
+ mkdir_or_exist(osp.join(out_dir, 'annotations', 'test'))
86
+
87
+ print('Generating images and annotations...')
88
+ # process data from the child dir on the first rank
89
+ cur_dir, dirs, files = list(os.walk(raw_data_root))[0]
90
+ print('====================')
91
+
92
+ files = list(filter(lambda x: x.endswith('.zip'), files))
93
+
94
+ with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
95
+ for file in files:
96
+ # search data folders for training,validation,test
97
+ mode = list(
98
+ filter(lambda x: file.lower().find(x) != -1,
99
+ ['training', 'test', 'validation']))[0]
100
+ file_root = osp.join(tmp_dir, file[:-4])
101
+ file_type = 'images' if file.find('Anno') == -1 and file.find(
102
+ 'GT') == -1 else 'annotations'
103
+ extract_img(file_root, osp.join(cur_dir, file), out_dir, mode,
104
+ file_type)
105
+
106
+ print('Done!')
107
+
108
+
109
+ if __name__ == '__main__':
110
+ main()
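The annotation branch of extract_img rewrites the original REFUGE grayscale codes (0, 128, 255) into contiguous labels (1, 2, 0). A tiny check of that remapping on a hand-made mask; the rewrite order is safe only because the new values 1 and 2 never collide with the codes still waiting to be rewritten:

import numpy as np

img = np.array([[0, 128], [255, 255]], dtype=np.uint8)   # toy single-channel mask
out = img.copy()
out[out == 0] = 1
out[out == 128] = 2
out[out == 255] = 0
print(out)
# [[1 2]
#  [0 0]]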
tools/dataset_converters/stare.py ADDED
@@ -0,0 +1,167 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import gzip
+ import os
+ import os.path as osp
+ import tarfile
+ import tempfile
+
+ import mmcv
+ from mmengine.utils import mkdir_or_exist
+
+ STARE_LEN = 20
+ TRAINING_LEN = 10
+
+
+ def un_gz(src, dst):
+     g_file = gzip.GzipFile(src)
+     with open(dst, 'wb+') as f:
+         f.write(g_file.read())
+     g_file.close()
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description='Convert STARE dataset to mmsegmentation format')
+     parser.add_argument('image_path', help='the path of stare-images.tar')
+     parser.add_argument('labels_ah', help='the path of labels-ah.tar')
+     parser.add_argument('labels_vk', help='the path of labels-vk.tar')
+     parser.add_argument('--tmp_dir', help='path of the temporary directory')
+     parser.add_argument('-o', '--out_dir', help='output path')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     image_path = args.image_path
+     labels_ah = args.labels_ah
+     labels_vk = args.labels_vk
+     if args.out_dir is None:
+         out_dir = osp.join('data', 'STARE')
+     else:
+         out_dir = args.out_dir
+
+     print('Making directories...')
+     mkdir_or_exist(out_dir)
+     mkdir_or_exist(osp.join(out_dir, 'images'))
+     mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
+     mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
+     mkdir_or_exist(osp.join(out_dir, 'annotations'))
+     mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
+     mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
+
+     with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+         mkdir_or_exist(osp.join(tmp_dir, 'gz'))
+         mkdir_or_exist(osp.join(tmp_dir, 'files'))
+
+         print('Extracting stare-images.tar...')
+         with tarfile.open(image_path) as f:
+             f.extractall(osp.join(tmp_dir, 'gz'))
+
+         for filename in os.listdir(osp.join(tmp_dir, 'gz')):
+             un_gz(
+                 osp.join(tmp_dir, 'gz', filename),
+                 osp.join(tmp_dir, 'files',
+                          osp.splitext(filename)[0]))
+
+         now_dir = osp.join(tmp_dir, 'files')
+
+         assert len(os.listdir(now_dir)) == STARE_LEN, \
+             f'len(os.listdir(now_dir)) != {STARE_LEN}'
+
+         for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
+             img = mmcv.imread(osp.join(now_dir, filename))
+             mmcv.imwrite(
+                 img,
+                 osp.join(out_dir, 'images', 'training',
+                          osp.splitext(filename)[0] + '.png'))
+
+         for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
+             img = mmcv.imread(osp.join(now_dir, filename))
+             mmcv.imwrite(
+                 img,
+                 osp.join(out_dir, 'images', 'validation',
+                          osp.splitext(filename)[0] + '.png'))
+
+         print('Removing the temporary files...')
+
+     with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+         mkdir_or_exist(osp.join(tmp_dir, 'gz'))
+         mkdir_or_exist(osp.join(tmp_dir, 'files'))
+
+         print('Extracting labels-ah.tar...')
+         with tarfile.open(labels_ah) as f:
+             f.extractall(osp.join(tmp_dir, 'gz'))
+
+         for filename in os.listdir(osp.join(tmp_dir, 'gz')):
+             un_gz(
+                 osp.join(tmp_dir, 'gz', filename),
+                 osp.join(tmp_dir, 'files',
+                          osp.splitext(filename)[0]))
+
+         now_dir = osp.join(tmp_dir, 'files')
+
+         assert len(os.listdir(now_dir)) == STARE_LEN, \
+             f'len(os.listdir(now_dir)) != {STARE_LEN}'
+
+         for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
+             img = mmcv.imread(osp.join(now_dir, filename))
+             # The annotation img should be divided by 128, because some of
+             # the annotation imgs are not standard. We should set a threshold
+             # to convert the nonstandard annotation imgs. The value divided by
+             # 128 is equivalent to '1 if value >= 128 else 0'.
+             mmcv.imwrite(
+                 img[:, :, 0] // 128,
+                 osp.join(out_dir, 'annotations', 'training',
+                          osp.splitext(filename)[0] + '.png'))
+
+         for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
+             img = mmcv.imread(osp.join(now_dir, filename))
+             mmcv.imwrite(
+                 img[:, :, 0] // 128,
+                 osp.join(out_dir, 'annotations', 'validation',
+                          osp.splitext(filename)[0] + '.png'))
+
+         print('Removing the temporary files...')
+
+     with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+         mkdir_or_exist(osp.join(tmp_dir, 'gz'))
+         mkdir_or_exist(osp.join(tmp_dir, 'files'))
+
+         print('Extracting labels-vk.tar...')
+         with tarfile.open(labels_vk) as f:
+             f.extractall(osp.join(tmp_dir, 'gz'))
+
+         for filename in os.listdir(osp.join(tmp_dir, 'gz')):
+             un_gz(
+                 osp.join(tmp_dir, 'gz', filename),
+                 osp.join(tmp_dir, 'files',
+                          osp.splitext(filename)[0]))
+
+         now_dir = osp.join(tmp_dir, 'files')
+
+         assert len(os.listdir(now_dir)) == STARE_LEN, \
+             f'len(os.listdir(now_dir)) != {STARE_LEN}'
+
+         for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
+             img = mmcv.imread(osp.join(now_dir, filename))
+             mmcv.imwrite(
+                 img[:, :, 0] // 128,
+                 osp.join(out_dir, 'annotations', 'training',
+                          osp.splitext(filename)[0] + '.png'))
+
+         for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
+             img = mmcv.imread(osp.join(now_dir, filename))
+             mmcv.imwrite(
+                 img[:, :, 0] // 128,
+                 osp.join(out_dir, 'annotations', 'validation',
+                          osp.splitext(filename)[0] + '.png'))
+
+         print('Removing the temporary files...')
+
+     print('Done!')
+
+
+ if __name__ == '__main__':
+     main()
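The converter binarizes the STARE vessel annotations with integer division by 128, which for 8-bit pixel values is the same as thresholding at 128. A small sketch of that equivalence (the array values are made up for illustration):

    import numpy as np

    ann = np.array([0, 3, 127, 128, 200, 255], dtype=np.uint8)
    by_division = ann // 128
    by_threshold = (ann >= 128).astype(np.uint8)
    assert np.array_equal(by_division, by_threshold)
    print(by_division)  # [0 0 0 1 1 1]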
tools/dataset_converters/synapse.py ADDED
@@ -0,0 +1,155 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import os.path as osp
+
+ import nibabel as nib
+ import numpy as np
+ from mmengine.utils import mkdir_or_exist
+ from PIL import Image
+
+
+ def read_files_from_txt(txt_path):
+     with open(txt_path) as f:
+         files = f.readlines()
+     files = [file.strip() for file in files]
+     return files
+
+
+ def read_nii_file(nii_path):
+     img = nib.load(nii_path).get_fdata()
+     return img
+
+
+ def split_3d_image(img):
+     c, _, _ = img.shape
+     res = []
+     for i in range(c):
+         res.append(img[i, :, :])
+     return res
+
+
+ def label_mapping(label):
+     """Label mapping from TransUNet paper setting. It only has 9 classes,
+     which are 'background', 'aorta', 'gallbladder', 'left_kidney',
+     'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach', respectively.
+     Other foreground classes in the original dataset are all set to
+     background.
+
+     More details could be found here: https://arxiv.org/abs/2102.04306
+     """
+     maped_label = np.zeros_like(label)
+     maped_label[label == 8] = 1
+     maped_label[label == 4] = 2
+     maped_label[label == 3] = 3
+     maped_label[label == 2] = 4
+     maped_label[label == 6] = 5
+     maped_label[label == 11] = 6
+     maped_label[label == 1] = 7
+     maped_label[label == 7] = 8
+     return maped_label
+
+
+ def pares_args():
+     parser = argparse.ArgumentParser(
+         description='Convert synapse dataset to mmsegmentation format')
+     parser.add_argument(
+         '--dataset-path', type=str, help='synapse dataset path.')
+     parser.add_argument(
+         '--save-path',
+         default='data/synapse',
+         type=str,
+         help='save path of the dataset.')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = pares_args()
+     dataset_path = args.dataset_path
+     save_path = args.save_path
+
+     if not osp.exists(dataset_path):
+         raise ValueError('The dataset path does not exist. '
+                          'Please enter a correct dataset path.')
+     if not osp.exists(osp.join(dataset_path, 'img')) \
+             or not osp.exists(osp.join(dataset_path, 'label')):
+         raise FileNotFoundError('The dataset structure is incorrect. '
+                                 'Please check your dataset.')
+
+     train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
+     train_id = [idx[3:7] for idx in train_id]
+
+     test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
+     test_id = [idx[3:7] for idx in test_id]
+
+     mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
+     mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
+     mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
+     mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))
+
+     # It follows data preparation pipeline from here:
+     # https://github.com/Beckschen/TransUNet/tree/main/datasets
+     for i, idx in enumerate(train_id):
+         img_3d = read_nii_file(
+             osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
+         label_3d = read_nii_file(
+             osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
+
+         img_3d = np.clip(img_3d, -125, 275)
+         img_3d = (img_3d + 125) / 400
+         img_3d *= 255
+         img_3d = np.transpose(img_3d, [2, 0, 1])
+         img_3d = np.flip(img_3d, 2)
+
+         label_3d = np.transpose(label_3d, [2, 0, 1])
+         label_3d = np.flip(label_3d, 2)
+         label_3d = label_mapping(label_3d)
+
+         for c in range(img_3d.shape[0]):
+             img = img_3d[c]
+             label = label_3d[c]
+
+             img = Image.fromarray(img).convert('RGB')
+             label = Image.fromarray(label).convert('L')
+             img.save(
+                 osp.join(
+                     save_path, 'img_dir/train', 'case' + idx.zfill(4) +
+                     '_slice' + str(c).zfill(3) + '.jpg'))
+             label.save(
+                 osp.join(
+                     save_path, 'ann_dir/train', 'case' + idx.zfill(4) +
+                     '_slice' + str(c).zfill(3) + '.png'))
+
+     for i, idx in enumerate(test_id):
+         img_3d = read_nii_file(
+             osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
+         label_3d = read_nii_file(
+             osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
+
+         img_3d = np.clip(img_3d, -125, 275)
+         img_3d = (img_3d + 125) / 400
+         img_3d *= 255
+         img_3d = np.transpose(img_3d, [2, 0, 1])
+         img_3d = np.flip(img_3d, 2)
+
+         label_3d = np.transpose(label_3d, [2, 0, 1])
+         label_3d = np.flip(label_3d, 2)
+         label_3d = label_mapping(label_3d)
+
+         for c in range(img_3d.shape[0]):
+             img = img_3d[c]
+             label = label_3d[c]
+
+             img = Image.fromarray(img).convert('RGB')
+             label = Image.fromarray(label).convert('L')
+             img.save(
+                 osp.join(
+                     save_path, 'img_dir/val', 'case' + idx.zfill(4) +
+                     '_slice' + str(c).zfill(3) + '.jpg'))
+             label.save(
+                 osp.join(
+                     save_path, 'ann_dir/val', 'case' + idx.zfill(4) +
+                     '_slice' + str(c).zfill(3) + '.png'))
+
+
+ if __name__ == '__main__':
+     main()
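The slice export above windows the Synapse CT volumes to the intensity range [-125, 275] and rescales the result to [0, 255], following the TransUNet preprocessing it references. A standalone sketch of that normalization step (the helper name is illustrative):

    import numpy as np

    def window_ct(volume: np.ndarray, low: float = -125, high: float = 275) -> np.ndarray:
        """Clip a CT volume to [low, high] and rescale it to [0, 255]."""
        clipped = np.clip(volume, low, high)
        return (clipped - low) / (high - low) * 255

    # Values below/above the window saturate at 0/255.
    print(window_ct(np.array([-500.0, -125.0, 75.0, 275.0, 1000.0])))
    # [  0.    0.  127.5 255.  255. ]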
tools/dataset_converters/voc_aug.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import os.path as osp
+ from functools import partial
+
+ import numpy as np
+ from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress
+ from PIL import Image
+ from scipy.io import loadmat
+
+ AUG_LEN = 10582
+
+
+ def convert_mat(mat_file, in_dir, out_dir):
+     data = loadmat(osp.join(in_dir, mat_file))
+     mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
+     seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
+     Image.fromarray(mask).save(seg_filename, 'PNG')
+
+
+ def generate_aug_list(merged_list, excluded_list):
+     return list(set(merged_list) - set(excluded_list))
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description='Convert PASCAL VOC annotations to mmsegmentation format')
+     parser.add_argument('devkit_path', help='pascal voc devkit path')
+     parser.add_argument('aug_path', help='pascal voc aug path')
+     parser.add_argument('-o', '--out_dir', help='output path')
+     parser.add_argument(
+         '--nproc', default=1, type=int, help='number of process')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     devkit_path = args.devkit_path
+     aug_path = args.aug_path
+     nproc = args.nproc
+     if args.out_dir is None:
+         out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
+     else:
+         out_dir = args.out_dir
+     mkdir_or_exist(out_dir)
+     in_dir = osp.join(aug_path, 'dataset', 'cls')
+
+     track_parallel_progress(
+         partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
+         list(scandir(in_dir, suffix='.mat')),
+         nproc=nproc)
+
+     full_aug_list = []
+     with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
+         full_aug_list += [line.strip() for line in f]
+     with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
+         full_aug_list += [line.strip() for line in f]
+
+     with open(
+             osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
+                      'train.txt')) as f:
+         ori_train_list = [line.strip() for line in f]
+     with open(
+             osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
+                      'val.txt')) as f:
+         val_list = [line.strip() for line in f]
+
+     aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
+                                        val_list)
+     assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
+         AUG_LEN)
+
+     with open(
+             osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
+                      'trainaug.txt'), 'w') as f:
+         f.writelines(line + '\n' for line in aug_train_list)
+
+     aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
+     assert len(aug_list) == AUG_LEN - len(
+         ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
+                                                       len(ori_train_list))
+     with open(
+             osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
+             'w') as f:
+         f.writelines(line + '\n' for line in aug_list)
+
+     print('Done!')
+
+
+ if __name__ == '__main__':
+     main()
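`generate_aug_list` is a plain set difference, so `trainaug.txt` ends up holding every id from the original train split plus the SBD augmentation splits, minus anything that appears in the VOC val split. A toy illustration of that behavior (the ids below are made up):

    def generate_aug_list(merged_list, excluded_list):
        return list(set(merged_list) - set(excluded_list))

    ori_train = ['2007_0001', '2007_0002']
    full_aug = ['2007_0002', '2007_0003', '2007_0004']
    val = ['2007_0004']

    print(sorted(generate_aug_list(ori_train + full_aug, val)))
    # ['2007_0001', '2007_0002', '2007_0003']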
tools/dataset_tools/create_dataset.py ADDED
@@ -0,0 +1,185 @@
+ import os
+ from glob import glob
+ from typing import List, Literal
+ import shutil
+ from PIL import Image
+ import json
+ import numpy as np
+ from rich.progress import track
+ import cv2
+ from vegseg.datasets import GrassDataset
+ from sklearn.model_selection import train_test_split
+ import argparse
+
+
+ def give_color_to_mask(mask: np.ndarray, palette: List[int]) -> Image.Image:
+     """
+     Convert a mask to a color (palette-mode) image.
+     Args:
+         mask (np.ndarray): numpy array of shape (H, W)
+         palette (List[int]): flat list of RGB values
+     return:
+         color_mask (Image.Image): PIL Image of shape (H, W)
+     """
+     im = Image.fromarray(mask).convert("P")
+     im.putpalette(palette)
+     return im
+
+
+ def get_mask_by_json(filename: str) -> np.ndarray:
+     """
+     Convert a labelme-style json annotation to a mask.
+     Args:
+         filename (str): path to json file
+     return:
+         mask (np.ndarray): numpy array of shape (H, W)
+     """
+     json_file = json.load(open(filename))
+     img_height = json_file["imageHeight"]
+     img_width = json_file["imageWidth"]
+     # uint8 keeps the mask compatible with both cv2.fillPoly and PIL.
+     mask = np.zeros((img_height, img_width), dtype="uint8")
+     for shape in json_file["shapes"]:
+         label = int(shape["label"])
+         label -= 1
+         label = max(label, 0)
+         points = np.array(shape["points"]).astype(np.int32)
+         cv2.fillPoly(mask, [points], label)
+     return mask
+
+
+ def json_to_image(json_path, image_path):
+     """
+     Convert a json annotation to a color mask image.
+     Args:
+         json_path (str): path to json file
+         image_path (str): path to save image
+     return: None
+     """
+     mask = get_mask_by_json(json_path)
+     palette_list = GrassDataset.METAINFO["palette"]
+     palette = []
+     for palette_item in palette_list:
+         palette.extend(palette_item)
+     color_mask = give_color_to_mask(mask, palette)
+     color_mask.save(image_path)
+
+
+ def create_dataset(
+     image_paths: List[str],
+     ann_paths: List[str],
+     phase: Literal["train", "val"],
+     output_dir: str,
+ ):
+     """
+     Args:
+         image_paths (List[str]): list of image paths
+         ann_paths (List[str]): list of annotation paths
+         phase (Literal["train", "val"]): train or val
+         output_dir (str): path to save dataset
+     Return:
+         None
+     """
+     for image_path, ann_path in track(
+         zip(image_paths, ann_paths),
+         description=f"{phase} dataset",
+         total=len(image_paths),
+     ):
+         ann_save_path = os.path.join(
+             output_dir,
+             "ann_dir",
+             phase,
+             os.path.basename(ann_path).replace(".json", ".png"),
+         )
+
+         # Copy the image to the target directory.
+         new_image_path = os.path.join(
+             output_dir, "img_dir", phase, os.path.basename(image_path)
+         )
+         shutil.copy(image_path, new_image_path)
+
+         # Save the annotation to the target directory.
+         json_to_image(ann_path, ann_save_path)
+
+
+ def split_dataset(
+     root_path: str,
+     output_path: str,
+     split_ratio: float = 0.8,
+     shuffle: bool = True,
+     seed: int = 42,
+ ) -> None:
+     """
+     Split a dataset into train and validation sets.
+
+     Args:
+         root_path (str): Path to the dataset. The dataset should be organized as follows:
+             dataset_path/
+                 image1.tif
+                 image1.json
+                 ...
+                 imageN.tif
+                 imageN.json
+         output_path (str): Path to the output directory where the split dataset will be saved.
+         split_ratio (float, optional): Ratio of the dataset to be used for training. Defaults to 0.8.
+         shuffle (bool, optional): Whether to shuffle before splitting. Defaults to True.
+         seed (int, optional): Seed for the random number generator. Defaults to 42.
+     """
+     image_paths = glob(os.path.join(root_path, "*.tif"))
+     ann_paths = [filename.replace(".tif", ".json") for filename in image_paths]
+     assert len(image_paths) == len(
+         ann_paths
+     ), "Number of images and annotations do not match"
+     print(f"images: {len(image_paths)}, annotations: {len(ann_paths)}")
+
+     image_train, image_test, ann_train, ann_test = train_test_split(
+         image_paths,
+         ann_paths,
+         train_size=split_ratio,
+         random_state=seed,
+         shuffle=shuffle,
+     )
+     print(f"train: {len(image_train)}, test: {len(image_test)}")
+
+     os.makedirs(os.path.join(output_path, "img_dir", "train"), exist_ok=True)
+     os.makedirs(os.path.join(output_path, "img_dir", "val"), exist_ok=True)
+     os.makedirs(os.path.join(output_path, "ann_dir", "train"), exist_ok=True)
+     os.makedirs(os.path.join(output_path, "ann_dir", "val"), exist_ok=True)
+
+     create_dataset(image_train, ann_train, "train", output_path)
+     create_dataset(image_test, ann_test, "val", output_path)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--root", type=str, default="data/raw_data")
+     parser.add_argument("--output", type=str, default="data/grass")
+     parser.add_argument("--split_ratio", type=float, default=0.8)
+     parser.add_argument("--seed", type=int, default=42)
+     # argparse does not convert "True"/"False" strings with type=bool, so
+     # parse the flag explicitly to make "--shuffle False" actually work.
+     parser.add_argument(
+         "--shuffle",
+         type=lambda x: str(x).lower() in ("true", "1", "yes"),
+         default=True,
+     )
+     args = parser.parse_args()
+
+     root: str = args.root
+     output_path: str = args.output
+     split_ratio: float = args.split_ratio
+     seed: int = args.seed
+     shuffle: bool = args.shuffle
+
+     split_dataset(
+         root_path=root,
+         output_path=output_path,
+         split_ratio=split_ratio,
+         shuffle=shuffle,
+         seed=seed,
+     )
+
+     print("Dataset split finished")
+
+
+ if __name__ == "__main__":
+     # Usage example: python src/tools/split_dataset.py --root data/raw_data --output data/grass --split_ratio 0.8 --seed 42 --shuffle True
+     main()
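`get_mask_by_json` rasterizes labelme-style polygons into a class-index mask, shifting the 1-based labels stored in the json down to 0-based indices before filling each polygon. A minimal sketch of the same step with an in-memory annotation dict (the dict and its values are made up for illustration):

    import numpy as np
    import cv2

    # Hypothetical labelme-style annotation: one triangle of class "2".
    annotation = {
        "imageHeight": 8,
        "imageWidth": 8,
        "shapes": [
            {"label": "2", "points": [[1, 1], [6, 1], [1, 6]]},
        ],
    }

    mask = np.zeros((annotation["imageHeight"], annotation["imageWidth"]), dtype="uint8")
    for shape in annotation["shapes"]:
        index = max(int(shape["label"]) - 1, 0)  # 1-based label -> 0-based index
        points = np.array(shape["points"], dtype=np.int32)
        cv2.fillPoly(mask, [points], index)

    print(np.unique(mask))  # [0 1]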