import torch
import torch.nn as nn
import torch.nn.functional as F

# Small constant used to clamp denominators away from zero.
EPSILONDIV = 1e-4

# Band centre wavelengths (nm) and band widths (nm): blue, red, NIR, SWIR.
WLMID = torch.tensor([462, 655.5, 843, 1599], dtype=torch.float32)
BWIDTH = torch.tensor([48, 81, 142, 70], dtype=torch.float32)


def safe_divide(numerator, denominator, eps=EPSILONDIV):
    """Elementwise division; denominators below ``eps`` (including negative
    ones) are replaced by ``eps``."""
    denominator = torch.where(denominator < eps, torch.full_like(denominator, eps), denominator)
    return numerator / denominator


def NDVI(image):
    """Normalized difference vegetation index: (NIR - red) / (NIR + red)."""
    nir = image[2:3]
    red = image[1:2]
    return safe_divide(nir - red, nir + red)


def BLUENIRndsi(image):
    """Normalized difference of blue and NIR: (blue - NIR) / (blue + NIR)."""
    blue = image[0:1]
    nir = image[2:3]
    return safe_divide(blue - nir, blue + nir)


def BLUESWIRndsi(image):
    """Normalized difference of blue and SWIR: (blue - SWIR) / (blue + SWIR)."""
    blue = image[0:1]
    swir = image[3:4]
    return safe_divide(blue - swir, blue + swir)


def REDSWIRratio(image):
    """Ratio of the red band to the SWIR band."""
    red = image[1:2]
    swir = image[3:4]
    return safe_divide(red, swir)


def trapz(tensor, x):
    """Trapezoidal integration of ``tensor`` over its channel dimension.

    x: 1-D tensor of sample positions (wavelengths), one per channel.
    """
    x = x.to(tensor.device)
    left = tensor[:-1]
    right = tensor[1:]
    dx = (x[1:] - x[:-1]).view(-1, 1, 1)
    return torch.sum(dx * (left + right) / 2.0, dim=0, keepdim=True)


def whiteness(image, wl=WLMID):
    """Spectral flatness: integrated deviation from an ideal flat spectrum.

    wl must hold one wavelength per channel of ``image``.
    """
    norm = torch.linalg.norm(image, dim=0, keepdim=True)
    norm = torch.where(norm < EPSILONDIV, torch.full_like(norm, EPSILONDIV), norm)
    normalized = image / norm
    ideal = 1.0 / torch.sqrt(torch.tensor(image.shape[0], dtype=torch.float32, device=image.device))
    diff = torch.abs(normalized - ideal)
    wrange = wl[-1] - wl[0]
    return trapz(diff, wl) / wrange


def brightness(image, wl=WLMID):
    """Mean reflectance over the spectral range spanned by ``wl``."""
    wrange = wl[-1] - wl[0]
    return trapz(image, wl) / wrange


def brightnessVIS(image):
    # Visible bands (blue, red) with their matching wavelengths.
    return brightness(image[0:2], WLMID[0:2])


def brightnessNIR(image):
    # Infrared bands (NIR, SWIR) with their matching wavelengths.
    return brightness(image[2:4], WLMID[2:4])


def whitenessVIS(image):
    return whiteness(image[0:2], WLMID[0:2])


def whitenessNIR(image):
    return whiteness(image[2:4], WLMID[2:4])


def centered_avg_pool(x: torch.Tensor, size: int) -> torch.Tensor:
    """
    x: (C, H, W) or (B, C, H, W)
    size: pooling window size (odd values give an exactly centered window)
    Returns: same shape as ``x``, average over a centered size x size patch.
    """
    pad = size // 2
    # Replicate-pad the borders so the output keeps the input's spatial size.
    x_padded = F.pad(x, (pad, pad, pad, pad), mode="replicate")
    return F.avg_pool2d(x_padded, kernel_size=size, stride=1)


def mconvolution(original_layer: torch.Tensor,
                 size: int,
                 maskconv: torch.Tensor = None) -> torch.Tensor:
    """
    Mean convolution: centered average, optionally renormalized by ``maskconv``
    (the valid-pixel fraction per window, see ``covolvemask``). Note that this
    renormalization yields the masked mean only if invalid pixels in
    ``original_layer`` hold the value zero.
    """
    avg = centered_avg_pool(original_layer, size)
    if maskconv is not None:
        avg = avg / maskconv
    return avg


def sconvolution(original_layer: torch.Tensor,
                 mean_layer: torch.Tensor,
                 size: int,
                 maskconv: torch.Tensor = None,
                 zeros: torch.Tensor = None) -> torch.Tensor:
    """
    Standard-deviation-like convolution:
        sqrt( E[x^2] - (E[x])^2 )  where E is the centered average.
    zeros: tensor of zeros broadcastable to ``original_layer``;
           defaults to zeros_like(original_layer).
    """
    avg2 = centered_avg_pool(original_layer * original_layer, size)
    if maskconv is not None:
        avg2 = avg2 / maskconv

    mean_sq = mean_layer * mean_layer

    if zeros is None:
        zeros = torch.zeros_like(original_layer)

    # Guard against small negative variances caused by floating-point error.
    diff = avg2 - mean_sq
    out = torch.sqrt(torch.clamp(diff, min=0.0))
    return torch.where(diff > 0, out, zeros)


def standard_deviation_conv(X, fun=None, size=5, maskconv=None):
    """Apply ``fun`` to X (if given) and return the windowed standard deviation."""
    if fun is not None:
        X = fun(X)
    mean_layer = mconvolution(X, size, maskconv=maskconv)
    std_layer = sconvolution(X, mean_layer, size, maskconv=maskconv)
    return std_layer


def mean_conv(X, fun=None, size=5, maskconv=None):
    """Apply ``fun`` to X (if given) and return the windowed mean."""
    if fun is not None:
        X = fun(X)
    mean_layer = mconvolution(X, size, maskconv=maskconv)
    return mean_layer


def covolvemask(maskvalid: torch.Tensor, size: int) -> torch.Tensor:
    """
    maskvalid: torch.Tensor of dtype torch.bool or 0/1 ints, shape (H, W)
    size: the window size for centered avg pooling
    Returns: float tensor of shape (1, H, W) holding the valid-pixel fraction
    per window (set to 1 at invalid pixels to avoid division by zero).
    """
    maskvalid = maskvalid.to(torch.bool)
    mask_f = maskvalid.to(torch.float32).unsqueeze(0)

    mask_cov = centered_avg_pool(mask_f, size)

    ones = torch.ones_like(mask_cov)
    return torch.where(maskvalid, mask_cov, ones)


def feature_generator(X, maskvalid=None):
    """
    X: (4, H, W) tensor with bands ordered blue, red, NIR, SWIR.
    maskvalid: optional (H, W) bool mask of valid pixels.
    Returns: (40, H, W) feature stack.
    """
    dims = (40, X.shape[-2], X.shape[-1])
    features = torch.zeros(dims, dtype=torch.float32, device=X.device)

    TOA_REFL_BLUE = X[0]
    TOA_REFL_RED = X[1]
    TOA_REFL_NIR = X[2]
    TOA_REFL_SWIR = X[3]

    # Raw bands.
    features[0] = TOA_REFL_BLUE
    features[30] = TOA_REFL_RED

    # Pixelwise spectral indices.
    features[2] = whitenessVIS(X)
    features[3] = REDSWIRratio(X)
    features[7] = BLUENIRndsi(X)
    features[13] = NDVI(X)
    features[24] = BLUESWIRndsi(X)
    features[25] = whitenessNIR(X)
    features[36] = whiteness(X)
    features[37] = brightnessVIS(X)

    # Valid-pixel fraction per window, used to renormalize the convolutions.
    if maskvalid is not None:
        mask5 = covolvemask(maskvalid, 5)
        mask3 = covolvemask(maskvalid, 3)
    else:
        mask5 = None
        mask3 = None

    # 5x5 windowed standard deviations.
    features[1] = standard_deviation_conv(X, whitenessVIS, 5, mask5)
    features[5] = standard_deviation_conv(X, NDVI, 5, mask5)
    features[11] = standard_deviation_conv(TOA_REFL_SWIR[None], None, 5, mask5)
    features[12] = standard_deviation_conv(X, REDSWIRratio, 5, mask5)
    features[15] = standard_deviation_conv(X, whiteness, 5, mask5)
    features[17] = standard_deviation_conv(TOA_REFL_BLUE[None], None, 5, mask5)
    features[18] = standard_deviation_conv(X, BLUESWIRndsi, 5, mask5)
    features[21] = standard_deviation_conv(X, BLUENIRndsi, 5, mask5)
    features[27] = standard_deviation_conv(X, whitenessVIS, 5, mask5)
    features[29] = standard_deviation_conv(X, brightnessNIR, 5, mask5)
    features[32] = standard_deviation_conv(TOA_REFL_RED[None], None, 5, mask5)
    features[35] = standard_deviation_conv(X, brightness, 5, mask5)
    features[38] = standard_deviation_conv(TOA_REFL_NIR[None], None, 5, mask5)

    # 3x3 windowed standard deviations.
    features[10] = standard_deviation_conv(X, whitenessNIR, 3, mask3)
    features[19] = standard_deviation_conv(X, REDSWIRratio, 3, mask3)
    features[20] = standard_deviation_conv(X, NDVI, 3, mask3)
    features[23] = standard_deviation_conv(X, whitenessVIS, 3, mask3)
    features[26] = standard_deviation_conv(TOA_REFL_BLUE[None], None, 3, mask3)
    features[28] = standard_deviation_conv(X, whiteness, 3, mask3)
    features[31] = standard_deviation_conv(X, BLUESWIRndsi, 3, mask3)
    features[34] = standard_deviation_conv(X, BLUENIRndsi, 3, mask3)

    # 5x5 windowed means.
    features[4] = mean_conv(X, NDVI, 5, mask5)
    features[6] = mean_conv(X, REDSWIRratio, 5, mask5)
    features[9] = mean_conv(X, whitenessVIS, 5, mask5)
    features[16] = mean_conv(X, whitenessNIR, 5, mask5)
    features[22] = mean_conv(X, BLUENIRndsi, 5, mask5)

    # 3x3 windowed means.
    features[8] = mean_conv(X, whitenessVIS, 3, mask3)
    features[14] = mean_conv(X, NDVI, 3, mask3)
    features[33] = mean_conv(X, BLUENIRndsi, 3, mask3)
    features[39] = mean_conv(TOA_REFL_BLUE[None], None, 3, mask3)

    return features


def feature_generator_batch(X, maskvalid=None):
    """
    X: (B, C, H, W)
    maskvalid: (H, W) or None
    Returns: (B, 40, H, W)
    """
    dims = (X.shape[0], 40, X.shape[-2], X.shape[-1])
    features = torch.zeros(dims, dtype=torch.float32, device=X.device)

    for i in range(X.shape[0]):
        features[i] = feature_generator(X[i], maskvalid=maskvalid)

    return features
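

# Example (illustrative sketch, not part of the pipeline): feature extraction
# for a random 4-band batch. The shapes and the all-valid mask are assumptions.
#
#     X = torch.rand(2, 4, 64, 64)                   # (B, C, H, W) reflectances
#     mask = torch.ones(64, 64, dtype=torch.bool)    # every pixel valid
#     feats = feature_generator_batch(X, mask)       # -> (2, 40, 64, 64)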


class CloudMaskOne(nn.Module):
    def __init__(self,
                 hidden_layer_sizes=(21, 20),
                 activation='relu',
                 last_activation='sigmoid',
                 dropout_rate=0.0,
                 input_dim=40,
                 batch_norm=False):
        super().__init__()
        self.input_dim = input_dim

        activations = {
            'relu': nn.ReLU(inplace=True),
            'tanh': nn.Tanh(),
            'leaky_relu': nn.LeakyReLU(inplace=True),
        }
        self.act = activations[activation]
        self.last_act = {'sigmoid': nn.Sigmoid(),
                         'softmax': nn.Softmax(dim=1),
                         'identity': nn.Identity()}[last_activation]

        # 1x1 "normalization" convolution applied to the feature stack.
        self.norm_conv = nn.Conv2d(input_dim, 40, 1, bias=not batch_norm)
        self.norm_bn = nn.BatchNorm2d(40) if batch_norm else None

        # Hidden blocks: 1x1 conv (a per-pixel MLP), optional BN and dropout.
        self.layers = nn.ModuleList()
        in_ch = 40
        for out_ch in hidden_layer_sizes:
            conv = nn.Conv2d(in_ch, out_ch, 1, bias=not batch_norm)
            bn = nn.BatchNorm2d(out_ch) if batch_norm else None
            do = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
            self.layers.append(nn.ModuleDict({'conv': conv, 'bn': bn, 'dropout': do}))
            in_ch = out_ch

        self.final_conv = nn.Conv2d(in_ch, 1, 1)

    def forward(self, x, maskvalid=None):
        # Expand the raw bands into the 40-channel feature stack.
        x = feature_generator_batch(x, maskvalid=maskvalid)

        x = self.norm_conv(x)
        if self.norm_bn is not None:
            x = self.norm_bn(x)

        for blk in self.layers:
            x = self.act(blk['conv'](x))
            if blk['bn'] is not None:
                x = blk['bn'](x)
            if blk['dropout'] is not None:
                x = blk['dropout'](x)

        x = self.final_conv(x)
        return self.last_act(x)
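

# Minimal smoke test (an illustrative sketch; the input size, mask, and
# default hyperparameters below are assumptions, not values from this module).
if __name__ == "__main__":
    model = CloudMaskOne()
    x = torch.rand(1, 4, 32, 32)                    # (B, C, H, W) TOA reflectances
    mask = torch.ones(32, 32, dtype=torch.bool)     # all pixels valid
    with torch.no_grad():
        prob = model(x, maskvalid=mask)
    print(prob.shape)  # torch.Size([1, 1, 32, 32])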