|
import random |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
from pytorch_lightning import LightningDataModule |
|
import os |
|
import re |
|
import yaml |
|
import rasterio |
|
import dvc.api |
|
|
|
|
|
# Pipeline parameters, read once via the DVC API when this module is imported.
params = dvc.api.params_show()

# Number of timesteps stacked into each image chip. Image rasters are expected
# to carry N_TIMESTEPS * 6 channels (checked and reshaped by PermuteTransform),
# and the per-channel normalization stats are repeated N_TIMESTEPS times.
N_TIMESTEPS = params['number_of_timesteps']
|
|
|
class ToTensorTransform:
    """Callable that copies array-like input into a new ``torch.Tensor``.

    Args:
        dtype: torch dtype the produced tensor is created with.
    """

    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, data):
        # torch.tensor always copies, so the result never aliases ``data``.
        target_dtype = self.dtype
        converted = torch.tensor(data, dtype=target_dtype)
        return converted
|
|
|
class NormalizeTransform:
    """Per-channel ``(x - mean) / std`` normalization delegated to torchvision.

    A ``transforms.Normalize`` instance is created on every call from the
    current ``mean`` / ``std`` attributes.

    Args:
        means: One mean value per channel.
        stds: One standard deviation per channel.
    """

    def __init__(self, means, stds):
        self.mean = means
        self.std = stds

    def __call__(self, data):
        normalize = transforms.Normalize(self.mean, self.std)
        return normalize(data)
|
|
|
class PermuteTransform:
    """Rearrange a flat ``(T * B, H, W)`` channel stack into ``(B, T, H, W)``.

    The leading axis of the input is read as ``T`` consecutive groups of ``B``
    channels, i.e. channel index ``t * B + b``. The output puts the band axis
    first and the time axis second.

    Args:
        n_timesteps: Number of timesteps ``T``. ``None`` (the default) uses the
            module-level ``N_TIMESTEPS`` at call time.
        n_bands: Number of channels per timestep ``B`` (default 6).

    Raises:
        ValueError: if the leading axis does not have ``T * B`` entries.
    """

    def __init__(self, n_timesteps=None, n_bands=6):
        self.n_timesteps = n_timesteps
        self.n_bands = n_bands

    def __call__(self, data):
        n_timesteps = N_TIMESTEPS if self.n_timesteps is None else self.n_timesteps
        n_bands = self.n_bands
        height, width = data.shape[-2:]

        expected = n_timesteps * n_bands
        if data.shape[0] != expected:
            # Report the leading (channel) axis, which is the one being validated.
            raise ValueError(f"Expected {expected} channels, got {data.shape[0]}")

        # Split channels into (time, band), then swap to (band, time).
        # reshape (rather than view) also works for non-contiguous input.
        data = data.reshape(n_timesteps, n_bands, height, width)
        return data.permute(1, 0, 2, 3)
|
|
|
class RandomFlipAndJitterTransform: |
|
""" |
|
Apply random horizontal and vertical flips, and channel jitter to the input image and corresponding mask. |
|
|
|
Parameters: |
|
----------- |
|
flip_prob : float, optional (default=0.5) |
|
Probability of applying horizontal and vertical flips to the image and mask. |
|
Each flip (horizontal and vertical) is applied independently based on this probability. |
|
|
|
jitter_std : float, optional (default=0.02) |
|
Standard deviation of the Gaussian noise added to the image channels for jitter. |
|
This value controls the intensity of the random noise applied to the image channels. |
|
|
|
Effects of Parameters: |
|
---------------------- |
|
flip_prob: |
|
- Higher flip_prob increases the likelihood of the image and mask being flipped. |
|
- A value of 0 means no flipping, while a value of 1 means always flip. |
|
|
|
jitter_std: |
|
- Higher jitter_std increases the intensity of the noise added to the image channels. |
|
- A value of 0 means no noise, while larger values add more significant noise. |
|
""" |
|
def __init__(self, flip_prob=0.5, jitter_std=0.02): |
|
self.flip_prob = flip_prob |
|
self.jitter_std = jitter_std |
|
|
|
def __call__(self, img, mask, field_ids): |
|
|
|
|
|
|
|
field_ids = field_ids.to(torch.int32) |
|
|
|
|
|
if random.random() < self.flip_prob: |
|
img = torch.flip(img, [2]) |
|
mask = torch.flip(mask, [1]) |
|
field_ids = torch.flip(field_ids, [1]) |
|
|
|
|
|
if random.random() < self.flip_prob: |
|
img = torch.flip(img, [3]) |
|
mask = torch.flip(mask, [2]) |
|
field_ids = torch.flip(field_ids, [2]) |
|
|
|
|
|
field_ids = field_ids.to(torch.uint16) |
|
|
|
|
|
noise = torch.randn(img.size()) * self.jitter_std |
|
img += noise |
|
|
|
return img, mask, field_ids |
|
|
|
def get_img_transforms():
    """Return the image transform pipeline applied after loading.

    The pipeline is currently empty, so the returned ``Compose`` is an identity.
    """
    image_steps = []
    return transforms.Compose(image_steps)
|
|
|
def get_mask_transforms():
    """Return the mask transform pipeline applied after loading.

    The pipeline is currently empty, so the returned ``Compose`` is an identity.
    """
    mask_steps = []
    return transforms.Compose(mask_steps)
|
|
|
class GeospatialDataset(Dataset):
    """Dataset of multi-timestep image chips with per-tier masks and field ids.

    Chips are read from ``<data_dir>/chips``. A chip belongs to the dataset when
    its file name has the form ``<prefix>_fold_<k>_merged.tif`` with ``k`` in
    ``fold_indicies``. Its mask and field-id rasters are the sibling files in
    which ``_merged.tif`` is replaced by ``_mask.tif`` and ``_field_ids.tif``.

    Args:
        data_dir: Dataset root containing ``chips/`` and ``chips_stats.yaml``.
        fold_indicies: Fold numbers to include. They also select the
            normalization stats.
        transform_img, transform_mask, transform_field_ids: Optional callables
            applied after loading and normalization.
        debug: Print diagnostic information while loading stats.
        subset_size: If set and smaller than the number of chips, keep a random
            subset of that size.
        data_augmentation: Optional dict with keys ``enabled`` (default True),
            ``flip_prob`` and ``jitter_std`` for RandomFlipAndJitterTransform.
            ``None`` disables augmentation. An empty dict enables it with the
            default parameters.
    """

    def __init__(self, data_dir, fold_indicies, transform_img=None, transform_mask=None, transform_field_ids=None, debug=False, subset_size=None, data_augmentation=None):
        self.data_dir = data_dir
        self.chips_dir = os.path.join(data_dir, 'chips')
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.transform_field_ids = transform_field_ids
        self.debug = debug
        # Parallel lists of file names (relative to chips_dir), one entry per chip.
        self.images = []
        self.masks = []
        self.field_ids = []
        self.data_augmentation = data_augmentation

        self.means, self.stds = self.load_stats(fold_indicies, N_TIMESTEPS)
        self.transform_img_load = self.get_img_load_transforms(self.means, self.stds)
        self.transform_mask_load = self.get_mask_load_transforms()
        self.transform_field_ids_load = self.get_field_ids_load_transforms()

        # Sorting makes the sample order (and therefore the subset chosen below
        # for a given random seed) independent of the filesystem listing order.
        pattern = self._fold_pattern(fold_indicies)
        for file in sorted(os.listdir(self.chips_dir)):
            if not pattern.fullmatch(file):
                continue
            self.images.append(file)
            self.masks.append(file.replace("_merged.tif", "_mask.tif"))
            self.field_ids.append(file.replace("_merged.tif", "_field_ids.tif"))

        if subset_size is not None and len(self.images) > subset_size:
            print(f"Randomly selecting {subset_size} samples from {len(self.images)} samples.")
            selected_indices = random.sample(range(len(self.images)), subset_size)
            self.images = [self.images[i] for i in selected_indices]
            self.masks = [self.masks[i] for i in selected_indices]
            self.field_ids = [self.field_ids[i] for i in selected_indices]

    @staticmethod
    def _fold_pattern(fold_indicies):
        """Compile a regex that fully matches ``*_fold_<k>_merged.tif`` for ``k`` in the folds."""
        # An escaped alternation (not a character class) handles multi-digit fold
        # numbers. The escaped dot plus fullmatch() rejects look-alikes and
        # sidecar files such as ``*_merged.tif.aux.xml``.
        alternatives = "|".join(re.escape(str(fold)) for fold in fold_indicies)
        return re.compile(rf".*_fold_(?:{alternatives})_merged\.tif")

    def load_stats(self, fold_indicies, n_timesteps=3):
        """Load per-channel normalization statistics pooled over ``fold_indicies``.

        Reads ``<data_dir>/chips_stats.yaml``. It must contain, for every
        requested fold ``k``, a ``fold_<k>`` entry with ``mean`` and ``std``
        (one value per channel) and ``n_chips``.

        Returns:
            ``(means, stds)``: lists of 0-d tensors with one entry per channel,
            repeated ``n_timesteps`` times to match the stacked chip layout.

        Raises:
            ValueError: if ``fold_indicies`` is empty or a fold has no stats.
            FileNotFoundError: if the stats file does not exist.
        """
        if not fold_indicies:
            raise ValueError("At least one fold index is required to load mean/std stats")
        stats_path = os.path.join(self.data_dir, 'chips_stats.yaml')
        if self.debug:
            print(f"Loading mean/std stats from {stats_path}")
        if not os.path.exists(stats_path):
            raise FileNotFoundError(f"mean/std stats file for dataset not found at {stats_path}")
        with open(stats_path, 'r') as file:
            stats = yaml.safe_load(file)

        mean_list, std_list, n_list = [], [], []
        for fold in fold_indicies:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"mean/std stats for fold {fold} not found in {stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {n_timesteps} timesteps.")
            mean_list.append(torch.Tensor(stats[key]['mean']))
            std_list.append(torch.Tensor(stats[key]['std']))
            n_list.append(stats[key]['n_chips'])

        fold_means = torch.stack(mean_list)      # (n_folds, n_channels)
        fold_vars = torch.stack(std_list) ** 2   # (n_folds, n_channels)
        n = torch.tensor(n_list, dtype=torch.float32).unsqueeze(1)  # (n_folds, 1)

        # NOTE(review): the mean is an unweighted average of the fold means,
        # while the variance is pooled with (n - 1) weights and omits the
        # between-fold spread of the means. Kept as-is because changing it
        # changes the normalization applied to every chip; revisit deliberately.
        means = list(fold_means.mean(dim=0).unbind())
        pooled_var = (fold_vars * (n - 1)).sum(dim=0) / (n.sum() - len(n_list))
        stds = list(torch.sqrt(pooled_var).unbind())

        # Chip channels are laid out as n_timesteps consecutive groups of bands.
        return means * n_timesteps, stds * n_timesteps

    def get_img_load_transforms(self, means, stds):
        """Tensor conversion, per-channel normalization, then (band, time) rearrangement."""
        return transforms.Compose([
            ToTensorTransform(torch.float32),
            NormalizeTransform(means, stds),
            PermuteTransform(),
        ])

    def get_mask_load_transforms(self):
        """Convert a mask array to a ``torch.uint8`` tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint8),
        ])

    def get_field_ids_load_transforms(self):
        """Convert a field-id array to a ``torch.uint16`` tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint16),
        ])

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _read_raster(path, dtype):
        """Read all bands of the raster at ``path`` cast to ``dtype``, closing the file."""
        with rasterio.open(path) as src:
            return src.read().astype(dtype)

    def __getitem__(self, idx):
        """Load chip ``idx`` and return ``(img, (targets, field_ids))``.

        ``img`` is the normalized image tensor rearranged by PermuteTransform.
        ``targets`` is a tuple with one ``torch.long`` tensor per mask band
        (tier). ``field_ids`` is the field-id tensor. Optional user transforms
        and, when enabled, RandomFlipAndJitterTransform are applied before the
        targets are split out.
        """
        img_path = os.path.join(self.chips_dir, self.images[idx])
        mask_path = os.path.join(self.chips_dir, self.masks[idx])
        field_ids_path = os.path.join(self.chips_dir, self.field_ids[idx])

        img = self._read_raster(img_path, 'uint16')
        mask = self._read_raster(mask_path, 'uint8')
        field_ids = self._read_raster(field_ids_path, 'uint16')

        img = self.transform_img_load(img)
        mask = self.transform_mask_load(mask)
        field_ids = self.transform_field_ids_load(field_ids)

        if self.transform_img is not None:
            img = self.transform_img(img)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)
        if self.transform_field_ids is not None:
            field_ids = self.transform_field_ids(field_ids)

        augmentation = self.data_augmentation
        if augmentation is not None and augmentation.get('enabled', True):
            img, mask, field_ids = RandomFlipAndJitterTransform(
                flip_prob=augmentation.get('flip_prob', 0.5),
                jitter_std=augmentation.get('jitter_std', 0.02),
            )(img, mask, field_ids)

        # One integer target tensor per mask band (tier).
        targets = tuple(mask[i, :, :].long() for i in range(mask.shape[0]))
        return img, (targets, field_ids)
|
|
|
class GeospatialDataModule(LightningDataModule):
    """Lightning data module that builds GeospatialDataset splits from disjoint folds.

    Args:
        data_dir: Dataset root passed to every GeospatialDataset.
        train_folds, val_folds, test_folds: Mutually exclusive fold indices.
        batch_size: Batch size for all dataloaders.
        num_workers: DataLoader worker processes; 0 loads in the main process.
        debug: Forwarded to the datasets.
        subsets: Optional ``{'train' | 'val' | 'test': size}`` subset sizes.
        data_augmentation: Augmentation config for the training split. An empty
            config (the default) leaves augmentation enabled with default
            parameters. Validation and test always run with it disabled.

    Raises:
        ValueError: if a fold set is missing or two fold sets overlap.
    """

    def __init__(self, data_dir, train_folds, val_folds, test_folds, batch_size=8, num_workers=4, debug=False, subsets=None, data_augmentation=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.debug = debug
        self.subsets = subsets if subsets is not None else {}
        self.data_augmentation = data_augmentation if data_augmentation is not None else {}

        GeospatialDataModule.validate_folds(train_folds, val_folds, test_folds)
        self.train_folds = train_folds
        self.val_folds = val_folds
        self.test_folds = test_folds

        self.transform_img = get_img_transforms()
        self.transform_mask = get_mask_transforms()

    @staticmethod
    def validate_folds(train, val, test):
        """Raise ValueError unless all three fold sets are given and pairwise disjoint."""
        if train is None or val is None or test is None:
            raise ValueError("All fold sets must be specified")

        train_set, val_set, test_set = set(train), set(val), set(test)
        if train_set & val_set or train_set & test_set or val_set & test_set:
            raise ValueError("Folds must be mutually exclusive")

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'validate', 'test' or None for all)."""
        print(f"Setting up GeospatialDataModule for stage: {stage}. Data augmentation config: {self.data_augmentation}")
        common_params = {
            'data_dir': self.data_dir,
            'debug': self.debug,
            # Deterministic post-load transforms, shared by every split.
            'transform_img': self.transform_img,
            'transform_mask': self.transform_mask,
            'data_augmentation': self.data_augmentation,
        }
        # Evaluation splits never use random augmentation.
        common_params_val_test = {
            **common_params,
            'data_augmentation': {
                'enabled': False,
            },
        }
        if stage in ('fit', None):
            self.train_dataset = GeospatialDataset(fold_indicies=self.train_folds, subset_size=self.subsets.get('train', None), **common_params)
        # Lightning calls setup('validate') for trainer.validate(); that stage
        # needs the validation split as well.
        if stage in ('fit', 'validate', None):
            self.val_dataset = GeospatialDataset(fold_indicies=self.val_folds, subset_size=self.subsets.get('val', None), **common_params_val_test)
        if stage in ('test', None):
            self.test_dataset = GeospatialDataset(fold_indicies=self.test_folds, subset_size=self.subsets.get('test', None), **common_params_val_test)

    def _make_dataloader(self, dataset, shuffle=False):
        """Build a DataLoader with this module's batch size and worker settings."""
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0, so only request persistence when workers exist.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)
|
|