# PROBAandSPOT/ensemble/c2r1km.py
import torch
import torch.nn as nn
import torch.nn.functional as F
## Generate feature maps ----------------------------------------------
EPSILONDIV = 1e-4
WLMID = torch.tensor([462, 655.5, 843, 1599], dtype=torch.float32)  # band centres in nm: BLUE, RED, NIR, SWIR
BWIDTH = torch.tensor([48, 81, 142, 70], dtype=torch.float32)       # band widths in nm (kept for reference; not used below)
## Spectral features ---------------------------------------------------
def safe_divide(numerator, denominator, eps=EPSILONDIV):
    # Clamp small denominators to eps to avoid division by (near) zero
    denominator = torch.where(denominator < eps, torch.full_like(denominator, eps), denominator)
    return numerator / denominator


# Band order assumed throughout: 0 = BLUE, 1 = RED, 2 = NIR, 3 = SWIR
def NDVI(image):  # image: (C, H, W)
    nir = image[2:3]
    red = image[1:2]
    return safe_divide(nir - red, nir + red)


def BLUENIRndsi(image):
    blue = image[0:1]
    nir = image[2:3]
    return safe_divide(blue - nir, blue + nir)


def BLUESWIRndsi(image):
    blue = image[0:1]
    swir = image[3:4]
    return safe_divide(blue - swir, blue + swir)


def REDSWIRratio(image):
    red = image[1:2]
    swir = image[3:4]
    return safe_divide(red, swir)

def trapz(tensor, x):  # tensor: (C, H, W), x: (C,)
    # Trapezoidal integration of the spectrum along the channel axis
    x = x.to(tensor.device)
    left = tensor[:-1]
    right = tensor[1:]
    dx = (x[1:] - x[:-1]).view(-1, 1, 1)
    return torch.sum(dx * (left + right) / 2.0, dim=0, keepdim=True)

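# Quick sanity check (illustrative only; `_trapz_example` is not part of the
# original module): for a spectrally flat input the trapezoidal integral
# reduces to the wavelength span, i.e. trapz(ones, WLMID) == WLMID[-1] - WLMID[0].
def _trapz_example():
    flat = torch.ones(4, 1, 1)
    area = trapz(flat, WLMID)  # shape (1, 1, 1)
    assert torch.allclose(area, WLMID[-1] - WLMID[0])
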
def whiteness(image):  # image: (C, H, W)
    # Spectral flatness: integrated deviation from a uniform (unit-norm) spectrum
    norm = torch.linalg.norm(image, dim=0, keepdim=True)
    norm = torch.where(norm < EPSILONDIV, torch.full_like(norm, EPSILONDIV), norm)
    normalized = image / norm
    ideal = 1.0 / torch.sqrt(torch.tensor(image.shape[0], dtype=torch.float32, device=image.device))
    diff = torch.abs(normalized - ideal)
    wrange = WLMID[-1] - WLMID[0]
    return trapz(diff, WLMID) / wrange


def brightness(image):  # image: (C, H, W)
    # Mean reflectance over the spectral range
    wrange = WLMID[-1] - WLMID[0]
    return trapz(image, WLMID) / wrange


def brightnessVIS(image):  # BLUE + RED = channels 0 and 1
    return brightness(image[0:2])


def brightnessNIR(image):  # NIR + SWIR = channels 2 and 3
    return brightness(image[2:4])


def whitenessVIS(image):
    return whiteness(image[0:2])


def whitenessNIR(image):
    return whiteness(image[2:4])

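# Illustrative usage (assumes the BLUE, RED, NIR, SWIR channel order used by
# feature_generator below; `_spectral_example` is not part of the original
# module): each spectral helper maps a (4, H, W) image to one (1, H, W) map.
def _spectral_example():
    image = torch.rand(4, 8, 8)
    for fun in (NDVI, BLUENIRndsi, BLUESWIRndsi, REDSWIRratio, brightness, whiteness):
        assert fun(image).shape == (1, 8, 8)
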
## Spatial features ---------------------------------------------------
def centered_avg_pool(x: torch.Tensor, size: int) -> torch.Tensor:
    """
    x: (B, C, H, W) or (C, H, W)
    size: pooling window size (assumed odd so the window stays centred)
    Returns: same shape as the input, each value the average over a centred size×size patch.
    """
    # pad H and W by size//2 on both sides with edge replication
    pad = size // 2
    # F.pad takes (pad_left, pad_right, pad_top, pad_bottom)
    x_padded = F.pad(x, (pad, pad, pad, pad), mode="replicate")
    # avg_pool2d with kernel=size, stride=1 restores the original H x W
    return F.avg_pool2d(x_padded, kernel_size=size, stride=1)

def mconvolution(original_layer: torch.Tensor,
                 size: int,
                 maskconv: torch.Tensor = None) -> torch.Tensor:
    """
    Mean convolution: centered average, optionally divided by maskconv.
    maskconv: same shape as the pooled layer or broadcastable
    """
    avg = centered_avg_pool(original_layer, size)
    if maskconv is not None:
        avg = avg / maskconv
    return avg

def sconvolution(original_layer: torch.Tensor,
                 mean_layer: torch.Tensor,
                 size: int,
                 maskconv: torch.Tensor = None,
                 zeros: torch.Tensor = None) -> torch.Tensor:
    """
    Standard-deviation-like convolution:
    sqrt( E[x^2] - (E[x])^2 ) where E is the centered average.
    zeros: tensor of zeros, broadcastable to the output; defaults to zeros_like(original_layer)
    """
    # E[x^2]
    avg2 = centered_avg_pool(original_layer * original_layer, size)
    if maskconv is not None:
        avg2 = avg2 / maskconv
    # (E[x])^2
    mean_sq = mean_layer * mean_layer
    # prepare zeros
    if zeros is None:
        zeros = torch.zeros_like(original_layer)
    # sqrt where positive
    diff = avg2 - mean_sq
    out = torch.sqrt(torch.clamp(diff, min=0.0))
    # set non-positive variances to zero
    return torch.where(diff > 0, out, zeros)

def standard_deviation_conv(X, fun=None, size=5, maskconv=None):
    # Estimate the pixel-wise feature
    if fun is not None:
        X = fun(X)
    # Estimate the mean
    mean_layer = mconvolution(X, size, maskconv=maskconv)
    # Estimate the standard deviation
    std_layer = sconvolution(X, mean_layer, size, maskconv=maskconv)
    return std_layer


def mean_conv(X, fun=None, size=5, maskconv=None):
    # Estimate the pixel-wise feature
    if fun is not None:
        X = fun(X)
    # Estimate the mean
    mean_layer = mconvolution(X, size, maskconv=maskconv)
    return mean_layer

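# Minimal sketch of the windowed statistics (the image size, window size and
# the `_window_stats_example` helper are illustrative only): the wrappers first
# apply the optional spectral function, then aggregate it over a centred
# size×size neighbourhood, so a (4, H, W) image yields (1, H, W) local mean and
# standard-deviation maps of, e.g., NDVI.
def _window_stats_example():
    image = torch.rand(4, 16, 16)
    mu = mean_conv(image, fun=NDVI, size=5)
    sd = standard_deviation_conv(image, fun=NDVI, size=5)
    assert mu.shape == (1, 16, 16) and sd.shape == (1, 16, 16)
    assert torch.all(sd >= 0)
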
def covolvemask(maskvalid: torch.Tensor, size: int) -> torch.Tensor:
    """
    maskvalid: torch.Tensor of dtype torch.bool or 0/1 ints, shape (H, W)
    size: the window size for centered avg pooling
    Returns: torch.FloatTensor of shape (H, W)
    """
    # 1) to float, with the (1, 1, H, W) layout expected by centered_avg_pool
    mask_f = maskvalid.to(torch.float32).unsqueeze(0).unsqueeze(0)
    # 2) centered average pool, back to (H, W)
    mask_cov = centered_avg_pool(mask_f, size)[0, 0]
    # 3) where maskvalid==1, keep mask_cov; else set to 1
    ones = torch.ones_like(mask_cov)
    return torch.where(maskvalid.bool(), mask_cov, ones)

def feature_generator(X, maskvalid=None):
    # Generate container
    dims = (40, X.shape[-2], X.shape[-1])
    features = torch.zeros(dims, dtype=torch.float32, device=X.device)
    # Identify the bands
    TOA_REFL_BLUE = X[0]
    TOA_REFL_RED = X[1]
    TOA_REFL_NIR = X[2]
    TOA_REFL_SWIR = X[3]
    # TOA ONLY features
    features[0] = TOA_REFL_BLUE
    features[30] = TOA_REFL_RED
    # Spectral features
    features[2] = whitenessVIS(X)
    features[3] = REDSWIRratio(X)
    features[7] = BLUENIRndsi(X)
    features[13] = NDVI(X)
    features[24] = BLUESWIRndsi(X)
    features[25] = whitenessNIR(X)
    features[36] = whiteness(X)
    features[37] = brightnessVIS(X)
    # Spatial features
    if maskvalid is not None:
        mask5 = covolvemask(maskvalid, 5)
        mask3 = covolvemask(maskvalid, 3)
    else:
        mask5 = None
        mask3 = None
    ## s5: standard deviation over a 5x5 window
    features[1] = standard_deviation_conv(X, whitenessVIS, 5, mask5)
    features[5] = standard_deviation_conv(X, NDVI, 5, mask5)
    features[11] = standard_deviation_conv(TOA_REFL_SWIR[None], None, 5, mask5)
    features[12] = standard_deviation_conv(X, REDSWIRratio, 5, mask5)
    features[15] = standard_deviation_conv(X, whiteness, 5, mask5)
    features[17] = standard_deviation_conv(TOA_REFL_BLUE[None], None, 5, mask5)
    features[18] = standard_deviation_conv(X, BLUESWIRndsi, 5, mask5)
    features[21] = standard_deviation_conv(X, BLUENIRndsi, 5, mask5)
    features[27] = standard_deviation_conv(X, whitenessVIS, 5, mask5)
    features[29] = standard_deviation_conv(X, brightnessNIR, 5, mask5)
    features[32] = standard_deviation_conv(TOA_REFL_RED[None], None, 5, mask5)
    features[35] = standard_deviation_conv(X, brightness, 5, mask5)
    features[38] = standard_deviation_conv(TOA_REFL_NIR[None], None, 5, mask5)
    ## s3: standard deviation over a 3x3 window
    features[10] = standard_deviation_conv(X, whitenessNIR, 3, mask3)
    features[19] = standard_deviation_conv(X, REDSWIRratio, 3, mask3)
    features[20] = standard_deviation_conv(X, NDVI, 3, mask3)
    features[23] = standard_deviation_conv(X, whitenessVIS, 3, mask3)
    features[26] = standard_deviation_conv(TOA_REFL_BLUE[None], None, 3, mask3)
    features[28] = standard_deviation_conv(X, whiteness, 3, mask3)
    features[31] = standard_deviation_conv(X, BLUESWIRndsi, 3, mask3)
    features[34] = standard_deviation_conv(X, BLUENIRndsi, 3, mask3)
    ## m5: mean over a 5x5 window
    features[4] = mean_conv(X, NDVI, 5, mask5)
    features[6] = mean_conv(X, REDSWIRratio, 5, mask5)
    features[9] = mean_conv(X, whitenessVIS, 5, mask5)
    features[16] = mean_conv(X, whitenessNIR, 5, mask5)
    features[22] = mean_conv(X, BLUENIRndsi, 5, mask5)
    ## m3: mean over a 3x3 window
    features[8] = mean_conv(X, whitenessVIS, 3, mask3)
    features[14] = mean_conv(X, NDVI, 3, mask3)
    features[33] = mean_conv(X, BLUENIRndsi, 3, mask3)
    features[39] = mean_conv(TOA_REFL_BLUE[None], None, 3, mask3)
    return features

def feature_generator_batch(X, maskvalid=None):
    """
    X: (B, C, H, W)
    maskvalid: (H, W) or None
    Returns: (B, 40, H, W)
    """
    # Generate container
    dims = (X.shape[0], 40, X.shape[-2], X.shape[-1])
    features = torch.zeros(dims, dtype=torch.float32, device=X.device)
    for i in range(X.shape[0]):
        features[i] = feature_generator(X[i], maskvalid=maskvalid)
    return features

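# Sketch of the full feature stack (batch size, tile size and the all-valid
# mask below are arbitrary choices; `_feature_stack_example` is not part of the
# original module): a (B, 4, H, W) batch of BLUE/RED/NIR/SWIR reflectances
# becomes a (B, 40, H, W) feature tensor.
def _feature_stack_example():
    X = torch.rand(2, 4, 32, 32)
    maskvalid = torch.ones(32, 32, dtype=torch.bool)
    feats = feature_generator_batch(X, maskvalid=maskvalid)
    assert feats.shape == (2, 40, 32, 32)
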
## Torch model -------------------------------------------------------
class CloudMaskOne(nn.Module):
    """Pixel-wise cloud mask head: a stack of 1x1 convolutions applied to the 40 generated features."""

    def __init__(self,
                 hidden_layer_sizes=(21, 20),
                 activation='relu',
                 last_activation='sigmoid',
                 dropout_rate=0.0,
                 input_dim=40,
                 batch_norm=False):
        super().__init__()
        self.input_dim = input_dim
        # Activation maps
        activations = {
            'relu': nn.ReLU(inplace=True),
            'tanh': nn.Tanh(),
            'leaky_relu': nn.LeakyReLU(inplace=True),
        }
        self.act = activations[activation]
        self.last_act = {'sigmoid': nn.Sigmoid(),
                         'softmax': nn.Softmax(dim=1),
                         'identity': nn.Identity()}[last_activation]
        # normalization conv (no activation)
        self.norm_conv = nn.Conv2d(input_dim, 40, 1, bias=not batch_norm)
        self.norm_bn = nn.BatchNorm2d(40) if batch_norm else None
        # hidden layers: conv with activation, then batch norm, then dropout
        self.layers = nn.ModuleList()
        in_ch = 40
        for out_ch in hidden_layer_sizes:
            conv = nn.Conv2d(in_ch, out_ch, 1, bias=not batch_norm)
            bn = nn.BatchNorm2d(out_ch) if batch_norm else None
            do = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
            self.layers.append(nn.ModuleDict({'conv': conv, 'bn': bn, 'dropout': do}))
            in_ch = out_ch
        # final conv + last activation (no BN or dropout)
        self.final_conv = nn.Conv2d(in_ch, 1, 1)

    def forward(self, x, maskvalid=None):
        # Generate the 40-channel feature stack from the raw bands
        if maskvalid is not None:
            x = feature_generator_batch(x, maskvalid)
        else:
            x = feature_generator_batch(x)
        # normalization conv (no activation)
        x = self.norm_conv(x)
        if self.norm_bn is not None:
            x = self.norm_bn(x)
        # hidden blocks: conv+act → bn → dropout
        for blk in self.layers:
            x = self.act(blk['conv'](x))
            if blk['bn'] is not None:
                x = blk['bn'](x)
            if blk['dropout'] is not None:
                x = blk['dropout'](x)
        # final conv + last activation
        x = self.final_conv(x)
        return self.last_act(x)

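# Minimal smoke test (illustrative only: the tile size and random reflectances
# are arbitrary, and no pretrained weights are loaded here): build the default
# model and push a dummy (B, 4, H, W) BLUE/RED/NIR/SWIR batch through it.
if __name__ == "__main__":
    model = CloudMaskOne()
    model.eval()
    with torch.no_grad():
        cloud_prob = model(torch.rand(1, 4, 32, 32))
    print(cloud_prob.shape)  # expected: torch.Size([1, 1, 32, 32])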