File size: 12,860 Bytes
5b24075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_lightning import LightningDataModule
import os
import re
import yaml
import rasterio
import dvc.api


# Hyperparameters tracked by DVC. N_TIMESTEPS is the temporal depth of each
# chip; images are stored with N_TIMESTEPS * 6 stacked band channels (see
# PermuteTransform below).
params = dvc.api.params_show()
N_TIMESTEPS = params['number_of_timesteps']

class ToTensorTransform:
    """Convert array-like input to a ``torch.Tensor`` of a fixed dtype.

    The dtype is chosen once at construction time; every call copies the
    input data into a fresh tensor of that dtype.
    """

    def __init__(self, dtype):
        # Target dtype applied to every converted tensor.
        self.dtype = dtype

    def __call__(self, data):
        """Return *data* as a new tensor with the configured dtype."""
        return torch.tensor(data, dtype=self.dtype)
    
class NormalizeTransform:
    """Channel-wise normalization with precomputed means and stds.

    Thin wrapper around ``torchvision.transforms.Normalize`` so the stats
    can be injected from the dataset's fold-specific YAML file.
    """

    def __init__(self, means, stds):
        # Keep the raw stats accessible for callers/debugging.
        self.mean = means
        self.std = stds
        # Build the torchvision transform once here instead of on every
        # __call__ — the stats never change after construction.
        self._normalize = transforms.Normalize(means, stds)

    def __call__(self, data):
        """Return *data* normalized per channel: (data - mean) / std."""
        return self._normalize(data)
    
class PermuteTransform:
    """Reshape a stacked (N_TIMESTEPS * 6, H, W) chip into (6, N_TIMESTEPS, H, W).

    The merged GeoTIFF stores all timesteps' bands flattened into one channel
    axis; this transform regroups them and puts bands first, which is the
    layout Prithvi expects.

    Raises
    ------
    ValueError
        If the channel dimension is not exactly ``N_TIMESTEPS * 6``.
    """

    def __call__(self, data):
        height, width = data.shape[-2:]

        # Ensure the channel dimension is as expected
        if data.shape[0] != N_TIMESTEPS * 6:
            # Fix: report the dimension that was actually checked (shape[0]);
            # the previous message printed data.shape[1], which was misleading.
            raise ValueError(f"Expected {N_TIMESTEPS*6} channels, got {data.shape[0]}")

        # Step 1: Reshape the data to group the N_TIMESTEPS*6 bands into N_TIMESTEPS groups of 6 bands
        data = data.view(N_TIMESTEPS, 6, height, width)

        # Step 2: Permute to bring the bands to the front
        data = data.permute(1, 0, 2, 3)  # NOTE: Prithvi wants it bands first # after this, shape is (6, N_TIMESTEPS, height, width)
        return data

class RandomFlipAndJitterTransform:
    """
    Apply random horizontal and vertical flips, and channel jitter to the input image and corresponding mask.

    Parameters:
    -----------
    flip_prob : float, optional (default=0.5)
        Probability of applying horizontal and vertical flips to the image and mask.
        Each flip (horizontal and vertical) is applied independently based on this probability.

    jitter_std : float, optional (default=0.02)
        Standard deviation of the Gaussian noise added to the image channels for jitter.
        This value controls the intensity of the random noise applied to the image channels.

    Effects of Parameters:
    ----------------------
    flip_prob:
        - Higher flip_prob increases the likelihood of the image and mask being flipped.
        - A value of 0 means no flipping, while a value of 1 means always flip.

    jitter_std:
        - Higher jitter_std increases the intensity of the noise added to the image channels.
        - A value of 0 means no noise, while larger values add more significant noise.
    """
    def __init__(self, flip_prob=0.5, jitter_std=0.02):
        self.flip_prob = flip_prob
        self.jitter_std = jitter_std

    def __call__(self, img, mask, field_ids):
        # Shapes (..., H, W)| img: torch.Size([6, N_TIMESTEPS, 224, 224]), mask: torch.Size([N_TIMESTEPS, 224, 224]), field_ids: torch.Size([1, 224, 224])
        
        # Temporarily convert field_ids to int32 for flipping (flip not implemented for uint16)
        field_ids = field_ids.to(torch.int32)

        # Random horizontal flip
        if random.random() < self.flip_prob:
            img = torch.flip(img, [2])
            mask = torch.flip(mask, [1])
            field_ids = torch.flip(field_ids, [1])

        # Random vertical flip
        if random.random() < self.flip_prob:
            img = torch.flip(img, [3])
            mask = torch.flip(mask, [2])
            field_ids = torch.flip(field_ids, [2])

        # Convert field_ids back to uint16
        field_ids = field_ids.to(torch.uint16)

        # Channel jitter
        noise = torch.randn(img.size()) * self.jitter_std
        img += noise

        return img, mask, field_ids

def get_img_transforms():
    """Return the module-level image transform pipeline (currently empty)."""
    pipeline = transforms.Compose([])
    return pipeline

def get_mask_transforms():
    """Return the module-level mask transform pipeline (currently empty)."""
    pipeline = transforms.Compose([])
    return pipeline

class GeospatialDataset(Dataset):
    """Dataset of multi-temporal geospatial chips with per-pixel masks and field ids.

    Chips are GeoTIFFs under ``<data_dir>/chips`` named
    ``*_fold_<k>_merged.tif``; each has companion ``*_mask.tif`` and
    ``*_field_ids.tif`` files. Only chips whose fold index appears in
    ``fold_indicies`` are included. Images are normalized with fold-specific
    statistics from ``chips_stats.yaml`` and reshaped to
    (bands, timesteps, H, W) by PermuteTransform.
    """

    def __init__(self, data_dir, fold_indicies, transform_img=None, transform_mask=None, transform_field_ids=None, debug=False, subset_size=None, data_augmentation=None):
        """Index chip/mask/field-id files for the requested folds.

        Parameters
        ----------
        data_dir : str
            Root directory containing ``chips/`` and ``chips_stats.yaml``.
        fold_indicies : sequence
            Fold indices to include in this split.
        transform_img, transform_mask, transform_field_ids : callable, optional
            Extra transforms applied after the built-in load transforms.
        debug : bool
            If True, print diagnostics while loading the stats file.
        subset_size : int, optional
            If set, randomly keep at most this many samples.
        data_augmentation : dict, optional
            Config for RandomFlipAndJitterTransform ('enabled', 'flip_prob',
            'jitter_std'); augmentation runs only when the dict is provided
            and its 'enabled' key is truthy (missing key defaults to True).
        """
        self.data_dir = data_dir
        self.chips_dir = os.path.join(data_dir, 'chips')
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.transform_field_ids = transform_field_ids
        self.debug = debug
        self.images = []
        self.masks = []
        self.field_ids = []
        self.data_augmentation = data_augmentation

        # Normalization stats are aggregated over the selected folds and
        # baked into the image load transform.
        self.means, self.stds = self.load_stats(fold_indicies, N_TIMESTEPS)
        self.transform_img_load = self.get_img_load_transforms(self.means, self.stds)
        self.transform_mask_load = self.get_mask_load_transforms()
        self.transform_field_ids_load = self.get_field_ids_load_transforms()
        
        # Adjust file selection based on fold
        # NOTE(review): the fold indices are joined into a regex character
        # class ("[013]"), so this assumes single-digit fold numbers.
        for file in os.listdir(self.chips_dir):
            if re.match(f".*_fold_[{''.join([str(f) for f in fold_indicies])}]_merged.tif", file):
                self.images.append(file)
                mask_file = file.replace("_merged.tif", "_mask.tif")
                self.masks.append(mask_file)
                field_ids_file = file.replace("_merged.tif", "_field_ids.tif")
                self.field_ids.append(field_ids_file)

        assert len(self.images) == len(self.masks), "Number of images and masks do not match"

        # If subset_size is specified, randomly select a subset of the data
        if subset_size is not None and len(self.images) > subset_size:
            print(f"Randomly selecting {subset_size} samples from {len(self.images)} samples.")
            selected_indices = random.sample(range(len(self.images)), subset_size)
            self.images = [self.images[i] for i in selected_indices]
            self.masks = [self.masks[i] for i in selected_indices]
            self.field_ids = [self.field_ids[i] for i in selected_indices]

    def load_stats(self, fold_indicies, n_timesteps=3):
        """Load normalization statistics for dataset from YAML file.

        Reads per-fold 'mean', 'std', and 'n_chips' entries from
        ``chips_stats.yaml``, aggregates them across the selected folds
        (means averaged; stds combined via pooled variance), and returns
        ``(means, stds)`` — each a per-band list repeated ``n_timesteps``
        times so it lines up with the stacked channel axis.

        Raises
        ------
        ValueError
            If a requested fold has no entry in the stats file.
        """
        stats_path = os.path.join(self.data_dir, 'chips_stats.yaml')
        if self.debug:
            print(f"Loading mean/std stats from {stats_path}")
        assert os.path.exists(stats_path), f"mean/std stats file for dataset not found at {stats_path}"
        with open(stats_path, 'r') as file:
            stats = yaml.safe_load(file)
        mean_list, std_list, n_list = [], [], []
        for fold in fold_indicies:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"mean/std stats for fold {fold} not found in {stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {n_timesteps} timesteps.")
            mean_list.append(torch.Tensor(stats[key]['mean'])) # list of 6 means
            std_list.append(torch.Tensor(stats[key]['std'])) # list of 6 stds
            n_list.append(stats[key]['n_chips']) # list of 6 ns
        # aggregate means and stds over all folds
        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
            # stds are waaaay more complex to aggregate
            # \sqrt{\frac{\sum_{i=1}^{n} (\sigma_i * (n_i - 1))}{\sum_{i=1}^{n} (n_i) - n}}
            variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
            n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))

        # make means and stds into 2d arrays, as torchvision would otherwise convert it into a 3d tensor which is incompatible with our 4d temporal images
        # https://github.com/pytorch/vision/blob/6e18cea3485066b7277785415bf2e0422dbdb9da/torchvision/transforms/_functional_tensor.py#L923
        return means * n_timesteps, stds * n_timesteps

    def get_img_load_transforms(self, means, stds):
        """Build the base image pipeline: to float32 tensor, normalize, permute."""
        return transforms.Compose([
            ToTensorTransform(torch.float32),
            NormalizeTransform(means, stds),
            PermuteTransform()
        ])

    def get_mask_load_transforms(self):
        """Build the base mask pipeline: to uint8 tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint8)
        ])
        
    def get_field_ids_load_transforms(self):
        """Build the base field-ids pipeline: to uint16 tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint16)
        ])
    
    def __len__(self):
        """Return the number of indexed chips."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and transform one sample.

        Returns ``(img, (targets, field_ids))`` where ``targets`` is a tuple
        with one ``torch.long`` (H, W) tensor per mask tier (mask channel).
        """
        img_path = os.path.join(self.chips_dir, self.images[idx])
        mask_path = os.path.join(self.chips_dir, self.masks[idx])
        field_ids_path = os.path.join(self.chips_dir, self.field_ids[idx])
        
        # NOTE(review): rasterio dataset handles are not explicitly closed
        # here; consider a `with rasterio.open(...)` block.
        img = rasterio.open(img_path).read().astype('uint16')
        mask = rasterio.open(mask_path).read().astype('uint8')
        field_ids = rasterio.open(field_ids_path).read().astype('uint16')
        
        # Apply our base transforms
        img = self.transform_img_load(img)
        mask = self.transform_mask_load(mask)
        field_ids = self.transform_field_ids_load(field_ids)

        # Apply additional transforms passed from GeospatialDataModule if applicable
        if self.transform_img is not None:
            img = self.transform_img(img)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)
        if self.transform_field_ids is not None:
            field_ids = self.transform_field_ids(field_ids)

        # Apply data augmentation if enabled
        if self.data_augmentation is not None and self.data_augmentation.get('enabled', True):
            img, mask, field_ids = RandomFlipAndJitterTransform(
                flip_prob=self.data_augmentation.get('flip_prob', 0.5), 
                jitter_std=self.data_augmentation.get('jitter_std', 0.02)
            )(img, mask, field_ids)

        # Load targets for given tiers
        num_tiers = mask.shape[0]
        targets = ()
        for i in range(num_tiers):
            targets += (mask[i, :, :].type(torch.long),)
        
        return img, (targets, field_ids)

class GeospatialDataModule(LightningDataModule):
    """Lightning data module wiring train/val/test GeospatialDatasets.

    Parameters
    ----------
    data_dir : str
        Root directory containing the ``chips`` folder and stats YAML.
    train_folds, val_folds, test_folds : sequence
        Mutually exclusive fold indices for each split.
    batch_size : int
        Batch size for all dataloaders.
    num_workers : int
        Worker count for all dataloaders (persistent workers are used).
    debug : bool
        Forwarded to the datasets for verbose stats loading.
    subsets : dict, optional
        Optional per-split subset sizes, e.g. ``{'train': 100}``.
    data_augmentation : dict, optional
        Augmentation config for the training dataset only; validation and
        test datasets are always built with augmentation disabled.
    """
    def __init__(self, data_dir, train_folds, val_folds, test_folds, batch_size=8, num_workers=4, debug=False, subsets=None, data_augmentation=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.debug = debug
        self.subsets = subsets if subsets is not None else {}
        self.data_augmentation = data_augmentation if data_augmentation is not None else {}
        
        GeospatialDataModule.validate_folds(train_folds, val_folds, test_folds)
        self.train_folds = train_folds
        self.val_folds = val_folds
        self.test_folds = test_folds

        # NOTE: Transforms on this level not used for now
        # (fix: this line previously lacked the leading '#', which made the
        # whole module fail to import with a SyntaxError)
        self.transform_img = get_img_transforms()
        self.transform_mask = get_mask_transforms()

    @staticmethod
    def validate_folds(train, val, test):
        """Raise ValueError unless all fold sets are given and pairwise disjoint."""
        if train is None or val is None or test is None:
            raise ValueError("All fold sets must be specified")
        
        if len(set(train) & set(val)) > 0 or len(set(train) & set(test)) > 0 or len(set(val) & set(test)) > 0:
            raise ValueError("Folds must be mutually exclusive")

    def setup(self, stage=None):
        """Create the datasets needed for *stage* ('fit', 'test', or None for both)."""
        print(f"Setting up GeospatialDataModule for stage: {stage}. Data augmentation config: {self.data_augmentation}")
        common_params = {
            'data_dir': self.data_dir,
            'debug': self.debug,
            'data_augmentation': self.data_augmentation
        }
        common_params_val_test = { # Never augment validation or test data
            **common_params,
             'data_augmentation': {
                'enabled': False
            }
        }
        if stage in ('fit', None):
            self.train_dataset = GeospatialDataset(fold_indicies=self.train_folds, subset_size=self.subsets.get('train', None), **common_params)
            self.val_dataset   = GeospatialDataset(fold_indicies=self.val_folds,   subset_size=self.subsets.get('val',   None), **common_params_val_test)
        if stage in ('test', None):
            self.test_dataset  = GeospatialDataset(fold_indicies=self.test_folds,  subset_size=self.subsets.get('test',  None), **common_params_val_test)

    def train_dataloader(self):
        """Shuffled training dataloader."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, shuffle=True)

    def val_dataloader(self):
        """Validation dataloader (no shuffling)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)

    def test_dataloader(self):
        """Test dataloader (no shuffling)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)