chenz53 committed on
Commit
2069947
·
verified ·
1 Parent(s): 816ac34

Upload 2 files

Browse files
Files changed (2) hide show
  1. dataload.py +235 -0
  2. tumor_mask.py +84 -0
dataload.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Optional, Sequence
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.distributed as ptdist
7
+ from monai.data import (
8
+ CacheDataset,
9
+ PersistentDataset,
10
+ partition_dataset,
11
+ )
12
+ from monai.data.utils import pad_list_data_collate
13
+ from monai.transforms import (
14
+ Compose,
15
+ CropForegroundd,
16
+ EnsureChannelFirstd,
17
+ LoadImaged,
18
+ Orientationd,
19
+ RandSpatialCropSamplesd,
20
+ ScaleIntensityRanged,
21
+ Spacingd,
22
+ SpatialPadd,
23
+ ToTensord,
24
+ Transform,
25
+ )
26
+
27
+
28
class PermuteImage(Transform):
    """Permute the dimensions of one tensor entry of a data dict.

    With the default order ``(3, 0, 1, 2)`` a channel-first volume of shape
    ``(C, H, W, D)`` becomes ``(D, C, H, W)``.
    NOTE(review): assumes the upstream pipeline produced a 4-D channel-first
    tensor (e.g. via EnsureChannelFirstd + ToTensord) — confirm against the
    transform chain that uses this.
    """

    def __init__(self, dims: Sequence[int] = (3, 0, 1, 2), key: str = "image"):
        """Generalized: permutation order and dict key are now parameters;
        the defaults reproduce the original hard-coded behavior exactly.

        Args:
            dims: target axis order passed to ``Tensor.permute``.
            key: dictionary key of the tensor to permute.
        """
        self.dims = tuple(dims)
        self.key = key

    def __call__(self, data):
        """Permute ``data[self.key]`` in the configured order and return
        the (mutated) data dict."""
        data[self.key] = data[self.key].permute(*self.dims)
        return data
36
+
37
+
38
class CTDataset:
    """Dataset factory for 3D CT volumes built on MONAI pipelines.

    Reads a MONAI-style JSON datalist with optional ``"train"`` and
    ``"validation"`` splits, builds a shared preprocessing pipeline
    (orientation, resampling, CT windowing, random crop + pad), and wraps
    the splits in ``CacheDataset`` (when caching is requested) or
    ``PersistentDataset`` otherwise, plus matching torch ``DataLoader``s.
    """

    def __init__(
        self,
        json_path: str,
        img_size: int,
        depth: int,
        mask_patch_size: int,
        patch_size: int,
        downsample_ratio: Sequence[float],
        cache_dir: str,
        batch_size: int = 1,
        val_batch_size: int = 1,
        num_workers: int = 4,
        cache_num: int = 0,
        cache_rate: float = 0.0,
        dist: bool = False,
    ):
        """Store configuration and load the JSON datalist.

        Args:
            json_path: path to a JSON file whose top-level dict may contain
                ``"train"`` and ``"validation"`` lists of data dicts.
            img_size: in-plane crop size (first two spatial axes) in voxels.
            depth: crop size along the third spatial axis.
            mask_patch_size: stored for downstream use; not read here.
            patch_size: stored for downstream use; not read here.
            downsample_ratio: target voxel spacing passed to ``Spacingd``.
            cache_dir: cache directory for ``PersistentDataset``.
            batch_size: training loader batch size.
            val_batch_size: validation/test loader batch size.
            num_workers: workers for dataset caching and data loading.
            cache_num: number of items to cache in ``CacheDataset``.
            cache_rate: fraction of items to cache in ``CacheDataset``;
                when both cache settings are falsy, ``PersistentDataset``
                is used instead.
            dist: if True, shard the datalists across
                ``torch.distributed`` ranks in :meth:`setup`.
        """
        super().__init__()
        self.json_path = json_path
        self.img_size = img_size
        self.depth = depth
        self.mask_patch_size = mask_patch_size
        self.patch_size = patch_size
        self.cache_dir = cache_dir
        self.downsample_ratio = downsample_ratio
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.cache_num = cache_num
        self.cache_rate = cache_rate
        self.dist = dist

        # Fix: the original used json.load(open(json_path, "r")) and leaked
        # the file handle; close it deterministically.
        with open(json_path, "r") as f:
            data_list = json.load(f)

        # Fix: the original only assigned these attributes when the key was
        # present, so setup() could raise AttributeError. Default to [].
        self.train_list = data_list.get("train", [])
        self.val_list = data_list.get("validation", [])

    def val_transforms(self):
        """Return the validation pipeline.

        NOTE(review): validation reuses the *randomized* training pipeline
        (random spatial crops) — presumably intentional for self-supervised
        pretraining; confirm if deterministic validation is expected.
        """
        return self.train_transforms()

    def train_transforms(self):
        """Build the per-volume preprocessing/augmentation pipeline."""
        transforms = Compose(
            [
                LoadImaged(keys=["image"]),
                EnsureChannelFirstd(keys=["image"]),
                Orientationd(keys=["image"], axcodes="RAS"),
                Spacingd(
                    keys=["image"],
                    pixdim=self.downsample_ratio,
                    mode=("bilinear"),
                ),
                # CT window [-175, 250] rescaled to [0, 1], clipped.
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-175,
                    a_max=250,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image"], source_key="image"),
                RandSpatialCropSamplesd(
                    keys=["image"],
                    roi_size=(self.img_size, self.img_size, self.depth),
                    random_size=False,
                    num_samples=1,
                ),
                # Pad back up in case the foreground crop came out smaller
                # than the requested ROI.
                SpatialPadd(
                    keys=["image"],
                    spatial_size=(self.img_size, self.img_size, self.depth),
                ),
                ToTensord(keys=["image"]),
                PermuteImage(),
            ]
        )

        return transforms

    def _use_cache(self) -> bool:
        """True when in-memory caching was requested.

        Replaces the original's confusing ``any([...]) > 0`` (a bool
        compared to an int) with the same truthiness test, stated plainly.
        """
        return bool(self.cache_num or self.cache_rate)

    def _make_dataset(self, data, transform, cache_num):
        """Wrap `data` in CacheDataset or PersistentDataset.

        Deduplicates the four construction sites the original repeated.
        """
        if self._use_cache():
            return CacheDataset(
                data,
                cache_num=cache_num,
                cache_rate=self.cache_rate,
                num_workers=self.num_workers,
                transform=transform,
            )
        return PersistentDataset(
            data,
            transform=transform,
            cache_dir=self.cache_dir,
        )

    def setup(self, stage: Optional[str] = None):
        """Build datasets for the given stage.

        Returns ``{"train": ..., "validation": ...}`` for stage None or
        "train", ``{"test": ...}`` for stage "test", and a dict of Nones
        otherwise. NOTE(review): as in the original, stage=None takes the
        train branch and returns before any "test" handling.
        """
        if stage in [None, "train"]:
            if self.dist:
                # Shard each split so every rank sees a distinct subset.
                train_partition = partition_dataset(
                    data=self.train_list,
                    num_partitions=ptdist.get_world_size(),
                    shuffle=True,
                    even_divisible=True,
                    drop_last=False,
                )[ptdist.get_rank()]
                valid_partition = partition_dataset(
                    data=self.val_list,
                    num_partitions=ptdist.get_world_size(),
                    shuffle=False,
                    even_divisible=True,
                    drop_last=False,
                )[ptdist.get_rank()]
            else:
                train_partition = self.train_list
                valid_partition = self.val_list

            return {
                "train": self._make_dataset(
                    train_partition, self.train_transforms(), self.cache_num
                ),
                # Validation gets a quarter of the training cache budget,
                # matching the original cache_num // 4.
                "validation": self._make_dataset(
                    valid_partition, self.val_transforms(), self.cache_num // 4
                ),
            }

        if stage == "test":
            return {
                "test": self._make_dataset(
                    self.val_list, self.val_transforms(), self.cache_num // 4
                )
            }

        return {"train": None, "validation": None}

    def train_dataloader(self, train_ds):
        """Shuffled training loader.

        ``pad_list_data_collate`` handles the per-sample lists emitted by
        ``RandSpatialCropSamplesd`` and pads ragged spatial shapes.
        """
        return torch.utils.data.DataLoader(
            train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            collate_fn=pad_list_data_collate,
        )

    def val_dataloader(self, valid_ds):
        """Deterministic-order validation loader."""
        return torch.utils.data.DataLoader(
            valid_ds,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=pad_list_data_collate,
        )

    def test_dataloader(self, test_ds):
        """Test loader; identical settings to the validation loader."""
        return torch.utils.data.DataLoader(
            test_ds,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=pad_list_data_collate,
        )
tumor_mask.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy import ndimage
3
+
4
+
5
def extract_tumor_and_peritumoral(mask_volume, peritumoral_margin=2, patch_size=(16, 16, 16)):
    """Extract tumor voxel coordinates and a patch-level tumor/peritumoral mask.

    The binary annotation mask is dilated by ``peritumoral_margin`` voxels to
    include a peritumoral band, then tiled into non-overlapping patches of
    ``patch_size``; a patch is flagged when it overlaps the dilated region.
    Voxels beyond the last whole patch along each axis are ignored (floor
    division), matching the original behavior.

    Parameters:
        mask_volume: 3D numpy array (z, y, x) with tumor annotations
            (1 for tumor, 0 for background).
        peritumoral_margin: margin size (in voxels) for the peritumoral
            region; the dilation structure is a cube of side
            ``2 * margin + 1``.
        patch_size: tuple (z, y, x) giving the patch edge lengths.

    Returns:
        tumor_coords: list of (z, y, x) coordinate tuples for tumor voxels.
        patch_mask: flat float array over patches (z-major order), 1.0 where
            the patch overlaps the dilated (tumor + peritumoral) region,
            else 0.0.

    Fixed: the original docstring claimed a ``peritumoral_coords`` return
    value that was never produced; the Python triple loop over patches is
    replaced by an equivalent vectorized reshape + any() reduction.
    """
    # Coordinates of annotated tumor voxels.
    zs, ys, xs = np.where(mask_volume == 1)
    tumor_coords = list(zip(zs, ys, xs))

    # Dilate with a full cube so every voxel within the margin (including
    # diagonals) joins the peritumoral region.
    side = 2 * peritumoral_margin + 1
    dilated_mask = ndimage.binary_dilation(
        mask_volume,
        structure=np.ones((side, side, side)),
    )

    z_steps, y_steps, x_steps = (
        dim // p for dim, p in zip(mask_volume.shape, patch_size)
    )
    pz, py, px = patch_size

    # Crop to whole patches, then reduce each (pz, py, px) block with a
    # single vectorized any() instead of a Python triple loop.
    cropped = dilated_mask[: z_steps * pz, : y_steps * py, : x_steps * px]
    blocks = cropped.reshape(z_steps, pz, y_steps, py, x_steps, px)
    patch_mask = blocks.any(axis=(1, 3, 5)).astype(float)

    return tumor_coords, patch_mask.flatten()
50
+
51
+
52
+ # Example usage
53
+ def main():
54
+ # Create sample data for testing
55
+ volume_shape = (96, 96, 96)
56
+ mask_volume = np.zeros(volume_shape)
57
+
58
+ # Create a synthetic tumor mask in the middle
59
+ mask_volume[40:60, 40:60, 40:60] = 1
60
+
61
+ # Test parameters
62
+ patch_size = (16, 16, 16)
63
+ peritumoral_margin = 5
64
+
65
+ # Call function and get results
66
+ tumor_coords, patch_mask = extract_tumor_and_peritumoral(
67
+ mask_volume, peritumoral_margin=peritumoral_margin, patch_size=patch_size
68
+ )
69
+
70
+ # Print test results
71
+ print(f"Volume shape: {volume_shape}")
72
+ print(f"Tumor volume: {len(tumor_coords)}")
73
+ print(f"Number of total patches: {patch_mask.shape}")
74
+ print(f"Number of patches containing tumor/peritumoral region: {np.sum(patch_mask)}")
75
+
76
+ # Validate results
77
+ assert len(tumor_coords) > 0, "No tumor coordinates found"
78
+ assert len(patch_mask) == np.prod(np.array(volume_shape) // np.array(patch_size)), "Incorrect patch mask size"
79
+
80
+ return tumor_coords, patch_mask
81
+
82
+
83
+ if __name__ == "__main__":
84
+ tumor_coords, patch_mask = main()