Spaces:
Sleeping
Sleeping
File size: 12,860 Bytes
5b24075 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_lightning import LightningDataModule
import os
import re
import yaml
import rasterio
import dvc.api
# Hyperparameters are tracked by DVC; read once at import time so every
# transform/dataset defined in this module sees the same configuration.
params = dvc.api.params_show()
# Number of acquisition dates stacked per chip; each date contributes 6 bands
# (see PermuteTransform, which reshapes N_TIMESTEPS*6 flat channels).
N_TIMESTEPS = params['number_of_timesteps']
class ToTensorTransform(object):
    """Convert array-like input into a ``torch.Tensor`` with a fixed dtype."""

    def __init__(self, dtype):
        # Target dtype used for every conversion performed by this transform.
        self.dtype = dtype

    def __call__(self, data):
        converted = torch.tensor(data, dtype=self.dtype)
        return converted
class NormalizeTransform(object):
    """Standardize a tensor per channel via torchvision's ``Normalize``."""

    def __init__(self, means, stds):
        # Per-channel statistics used to build the Normalize op on each call.
        self.mean = means
        self.std = stds

    def __call__(self, data):
        normalizer = transforms.Normalize(self.mean, self.std)
        return normalizer(data)
class PermuteTransform:
    """Rearrange a flat multi-date chip from (N_TIMESTEPS*6, H, W) to (6, N_TIMESTEPS, H, W)."""

    def __call__(self, data):
        height, width = data.shape[-2:]
        # Validate the flattened channel dimension before reshaping.
        if data.shape[0] != N_TIMESTEPS * 6:
            # BUG FIX: the message previously printed data.shape[1] even though
            # the check is on data.shape[0], producing a misleading diagnostic.
            raise ValueError(f"Expected {N_TIMESTEPS*6} channels, got {data.shape[0]}")
        # Step 1: Group the N_TIMESTEPS*6 flat bands into N_TIMESTEPS groups of 6 bands.
        data = data.view(N_TIMESTEPS, 6, height, width)
        # Step 2: Bring the bands to the front. NOTE: Prithvi wants it bands first.
        # After this, shape is (6, N_TIMESTEPS, height, width).
        data = data.permute(1, 0, 2, 3)
        return data
class RandomFlipAndJitterTransform:
"""
Apply random horizontal and vertical flips, and channel jitter to the input image and corresponding mask.
Parameters:
-----------
flip_prob : float, optional (default=0.5)
Probability of applying horizontal and vertical flips to the image and mask.
Each flip (horizontal and vertical) is applied independently based on this probability.
jitter_std : float, optional (default=0.02)
Standard deviation of the Gaussian noise added to the image channels for jitter.
This value controls the intensity of the random noise applied to the image channels.
Effects of Parameters:
----------------------
flip_prob:
- Higher flip_prob increases the likelihood of the image and mask being flipped.
- A value of 0 means no flipping, while a value of 1 means always flip.
jitter_std:
- Higher jitter_std increases the intensity of the noise added to the image channels.
- A value of 0 means no noise, while larger values add more significant noise.
"""
def __init__(self, flip_prob=0.5, jitter_std=0.02):
self.flip_prob = flip_prob
self.jitter_std = jitter_std
def __call__(self, img, mask, field_ids):
# Shapes (..., H, W)| img: torch.Size([6, N_TIMESTEPS, 224, 224]), mask: torch.Size([N_TIMESTEPS, 224, 224]), field_ids: torch.Size([1, 224, 224])
# Temporarily convert field_ids to int32 for flipping (flip not implemented for uint16)
field_ids = field_ids.to(torch.int32)
# Random horizontal flip
if random.random() < self.flip_prob:
img = torch.flip(img, [2])
mask = torch.flip(mask, [1])
field_ids = torch.flip(field_ids, [1])
# Random vertical flip
if random.random() < self.flip_prob:
img = torch.flip(img, [3])
mask = torch.flip(mask, [2])
field_ids = torch.flip(field_ids, [2])
# Convert field_ids back to uint16
field_ids = field_ids.to(torch.uint16)
# Channel jitter
noise = torch.randn(img.size()) * self.jitter_std
img += noise
return img, mask, field_ids
def get_img_transforms():
    """Return the extra image-transform pipeline; currently an empty (no-op) Compose."""
    pipeline = transforms.Compose([])
    return pipeline
def get_mask_transforms():
    """Return the extra mask-transform pipeline; currently an empty (no-op) Compose."""
    pipeline = transforms.Compose([])
    return pipeline
class GeospatialDataset(Dataset):
    """Dataset of multi-temporal geospatial chips with segmentation masks and field ids.

    Expects ``<data_dir>/chips`` to hold GeoTIFF triples per sample:
    ``*_fold_<k>_merged.tif`` (image), ``*_mask.tif`` and ``*_field_ids.tif``.
    Per-fold normalization statistics are read from ``<data_dir>/chips_stats.yaml``.
    """

    def __init__(self, data_dir, fold_indicies, transform_img=None, transform_mask=None, transform_field_ids=None, debug=False, subset_size=None, data_augmentation=None):
        self.data_dir = data_dir
        self.chips_dir = os.path.join(data_dir, 'chips')
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.transform_field_ids = transform_field_ids
        self.debug = debug
        self.images = []
        self.masks = []
        self.field_ids = []
        self.data_augmentation = data_augmentation
        self.means, self.stds = self.load_stats(fold_indicies, N_TIMESTEPS)
        self.transform_img_load = self.get_img_load_transforms(self.means, self.stds)
        self.transform_mask_load = self.get_mask_load_transforms()
        self.transform_field_ids_load = self.get_field_ids_load_transforms()
        # Adjust file selection based on fold.
        # NOTE(review): the regex character class only works for single-digit
        # fold indices — confirm folds are always 0-9.
        for file in os.listdir(self.chips_dir):
            if re.match(f".*_fold_[{''.join([str(f) for f in fold_indicies])}]_merged.tif", file):
                self.images.append(file)
                mask_file = file.replace("_merged.tif", "_mask.tif")
                self.masks.append(mask_file)
                field_ids_file = file.replace("_merged.tif", "_field_ids.tif")
                self.field_ids.append(field_ids_file)
        assert len(self.images) == len(self.masks), "Number of images and masks do not match"
        # If subset_size is specified, randomly select a subset of the data
        if subset_size is not None and len(self.images) > subset_size:
            print(f"Randomly selecting {subset_size} samples from {len(self.images)} samples.")
            selected_indices = random.sample(range(len(self.images)), subset_size)
            self.images = [self.images[i] for i in selected_indices]
            self.masks = [self.masks[i] for i in selected_indices]
            self.field_ids = [self.field_ids[i] for i in selected_indices]

    def load_stats(self, fold_indicies, n_timesteps=3):
        """Load normalization statistics for dataset from YAML file.

        Returns ``(means, stds)``: per-band statistics aggregated over the
        requested folds and tiled ``n_timesteps`` times to match the stacked
        channel layout of the loaded images.
        """
        stats_path = os.path.join(self.data_dir, 'chips_stats.yaml')
        if self.debug:
            print(f"Loading mean/std stats from {stats_path}")
        assert os.path.exists(stats_path), f"mean/std stats file for dataset not found at {stats_path}"
        with open(stats_path, 'r') as file:
            stats = yaml.safe_load(file)
        mean_list, std_list, n_list = [], [], []
        for fold in fold_indicies:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"mean/std stats for fold {fold} not found in {stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {n_timesteps} timesteps.")
            mean_list.append(torch.Tensor(stats[key]['mean']))  # list of 6 means
            std_list.append(torch.Tensor(stats[key]['std']))  # list of 6 stds
            n_list.append(stats[key]['n_chips'])  # number of chips in this fold
        # aggregate means and stds over all folds
        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
            # stds are waaaay more complex to aggregate: pooled variance
            # (FIX: the formula comment previously omitted the square on sigma)
            # \sqrt{\frac{\sum_{i=1}^{n} (\sigma_i^2 * (n_i - 1))}{\sum_{i=1}^{n} (n_i) - n}}
            variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
            n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))
        # make means and stds into 2d arrays, as torchvision would otherwise convert it into a 3d tensor which is incompatible with our 4d temporal images
        # https://github.com/pytorch/vision/blob/6e18cea3485066b7277785415bf2e0422dbdb9da/torchvision/transforms/_functional_tensor.py#L923
        return means * n_timesteps, stds * n_timesteps

    def get_img_load_transforms(self, means, stds):
        """Base image pipeline: to float32 tensor -> normalize -> bands-first permute."""
        return transforms.Compose([
            ToTensorTransform(torch.float32),
            NormalizeTransform(means, stds),
            PermuteTransform()
        ])

    def get_mask_load_transforms(self):
        """Base mask pipeline: to uint8 tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint8)
        ])

    def get_field_ids_load_transforms(self):
        """Base field-id pipeline: to uint16 tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint16)
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.chips_dir, self.images[idx])
        mask_path = os.path.join(self.chips_dir, self.masks[idx])
        field_ids_path = os.path.join(self.chips_dir, self.field_ids[idx])
        # FIX: use context managers so the rasterio dataset handles are closed;
        # the previous rasterio.open(...).read() leaked one handle per read.
        with rasterio.open(img_path) as src:
            img = src.read().astype('uint16')
        with rasterio.open(mask_path) as src:
            mask = src.read().astype('uint8')
        with rasterio.open(field_ids_path) as src:
            field_ids = src.read().astype('uint16')
        # Apply our base transforms
        img = self.transform_img_load(img)
        mask = self.transform_mask_load(mask)
        field_ids = self.transform_field_ids_load(field_ids)
        # Apply additional transforms passed from GeospatialDataModule if applicable
        if self.transform_img is not None:
            img = self.transform_img(img)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)
        if self.transform_field_ids is not None:
            field_ids = self.transform_field_ids(field_ids)
        # Apply data augmentation if enabled
        if self.data_augmentation is not None and self.data_augmentation.get('enabled', True):
            img, mask, field_ids = RandomFlipAndJitterTransform(
                flip_prob=self.data_augmentation.get('flip_prob', 0.5),
                jitter_std=self.data_augmentation.get('jitter_std', 0.02)
            )(img, mask, field_ids)
        # Load targets for given tiers (one long-typed target per mask layer)
        num_tiers = mask.shape[0]
        targets = ()
        for i in range(num_tiers):
            targets += (mask[i, :, :].type(torch.long),)
        return img, (targets, field_ids)
class GeospatialDataModule(LightningDataModule):
    """LightningDataModule wiring train/val/test ``GeospatialDataset``s from fold splits.

    Folds must be mutually exclusive across the three splits; augmentation is
    applied to training data only.
    """

    def __init__(self, data_dir, train_folds, val_folds, test_folds, batch_size=8, num_workers=4, debug=False, subsets=None, data_augmentation=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.debug = debug
        # Optional per-split sample caps, e.g. {'train': 100}.
        self.subsets = subsets if subsets is not None else {}
        self.data_augmentation = data_augmentation if data_augmentation is not None else {}
        GeospatialDataModule.validate_folds(train_folds, val_folds, test_folds)
        self.train_folds = train_folds
        self.val_folds = val_folds
        self.test_folds = test_folds
        # NOTE: Transforms on this level not used for now
        self.transform_img = get_img_transforms()
        self.transform_mask = get_mask_transforms()

    @staticmethod
    def validate_folds(train, val, test):
        """Raise ValueError unless all three fold sets are given and pairwise disjoint."""
        if train is None or val is None or test is None:
            raise ValueError("All fold sets must be specified")
        if len(set(train) & set(val)) > 0 or len(set(train) & set(test)) > 0 or len(set(val) & set(test)) > 0:
            raise ValueError("Folds must be mutually exclusive")

    def setup(self, stage=None):
        """Instantiate datasets for the requested stage ('fit', 'test', or None for both)."""
        print(f"Setting up GeospatialDataModule for stage: {stage}. Data augmentation config: {self.data_augmentation}")
        common_params = {
            'data_dir': self.data_dir,
            'debug': self.debug,
            'data_augmentation': self.data_augmentation
        }
        common_params_val_test = {  # Never augment validation or test data
            **common_params,
            'data_augmentation': {
                'enabled': False
            }
        }
        if stage in ('fit', None):
            self.train_dataset = GeospatialDataset(fold_indicies=self.train_folds, subset_size=self.subsets.get('train', None), **common_params)
            self.val_dataset = GeospatialDataset(fold_indicies=self.val_folds, subset_size=self.subsets.get('val', None), **common_params_val_test)
        if stage in ('test', None):
            self.test_dataset = GeospatialDataset(fold_indicies=self.test_folds, subset_size=self.subsets.get('test', None), **common_params_val_test)

    def _keep_workers_alive(self):
        # FIX: DataLoader raises if persistent_workers=True while num_workers=0;
        # only keep workers persistent when there are workers at all.
        return self.num_workers > 0

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self._keep_workers_alive(), shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self._keep_workers_alive())

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self._keep_workers_alive())
|