"""
UnCRtainTS Implementation
Author: Patrick Ebel (github/patrickTUM)
License: MIT
"""
import torch
import torch.nn as nn
import sys
sys.path.append("./model")
from src.backbones.utae import ConvLayer, ConvBlock, TemporallySharedBlock
from src.backbones.ltae import LTAE2d, LTAE2dtiny
S2_BANDS = 13
def get_norm_layer(out_channels, num_feats, n_groups=4, layer_type='batch'):
    if layer_type == 'batch':
        return nn.BatchNorm2d(out_channels)
    elif layer_type == 'instance':
        return nn.InstanceNorm2d(out_channels)
    elif layer_type == 'group':
        return nn.GroupNorm(num_channels=num_feats, num_groups=n_groups)
    # 'none' or any unrecognized identifier: apply no normalization
    # (returning a module avoids None ending up inside nn.Sequential)
    return nn.Identity()
class ResidualConvBlock(TemporallySharedBlock):
def __init__(
self,
nkernels,
pad_value=None,
norm="batch",
n_groups=4,
#last_relu=True,
k=3, s=1, p=1,
padding_mode="reflect",
):
super(ResidualConvBlock, self).__init__(pad_value=pad_value)
self.conv1 = ConvLayer(
nkernels=nkernels,
norm=norm,
last_relu=True,
k=k, s=s, p=p,
n_groups=n_groups,
padding_mode=padding_mode,
)
self.conv2 = ConvLayer(
nkernels=nkernels,
norm=norm,
last_relu=True,
k=k, s=s, p=p,
n_groups=n_groups,
padding_mode=padding_mode,
)
self.conv3 = ConvLayer(
nkernels=nkernels,
#norm='none',
#last_relu=False,
norm=norm,
last_relu=True,
k=k, s=s, p=p,
n_groups=n_groups,
padding_mode=padding_mode,
)
def forward(self, input):
out1 = self.conv1(input) # followed by built-in ReLU & norm
out2 = self.conv2(out1) # followed by built-in ReLU & norm
        out3 = input + self.conv3(out2)  # residual connection; note conv3 currently keeps its norm & ReLU (the norm='none' variant is commented out above)
return out3
class PreNorm(nn.Module):
def __init__(self, dim, fn, norm, n_groups=4):
super().__init__()
self.norm = get_norm_layer(dim, dim, n_groups, norm)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SE(nn.Module):
def __init__(self, inp, oup, expansion=0.25):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(oup, int(inp * expansion), bias=False),
nn.GELU(),
nn.Linear(int(inp * expansion), oup, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
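# Shape sketch (illustrative; SE is only instantiated from MBConv below):
#   se = SE(inp=64, oup=128)                # fc: 128 -> 16 -> 128
#   y = se(torch.randn(2, 128, 32, 32))     # y.shape == (2, 128, 32, 32)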
class MBConv(TemporallySharedBlock):
def __init__(self, inp, oup, downsample=False, expansion=4, norm='batch', n_groups=4):
super().__init__()
self.downsample = downsample
        stride = 2 if self.downsample else 1
hidden_dim = int(inp * expansion)
if self.downsample:
self.pool = nn.MaxPool2d(3, 2, 1)
self.proj = nn.Conv2d(inp, oup, 1, stride=1, padding=0, bias=False)
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride,
padding=1, padding_mode='reflect', groups=hidden_dim, bias=False),
get_norm_layer(hidden_dim, hidden_dim, n_groups, norm),
nn.GELU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, stride=1, padding=0, bias=False),
get_norm_layer(oup, oup, n_groups, norm),
)
else:
self.conv = nn.Sequential(
# pw
# down-sample in the first conv
nn.Conv2d(inp, hidden_dim, 1, stride=stride, padding=0, bias=False),
get_norm_layer(hidden_dim, hidden_dim, n_groups, norm),
nn.GELU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1, padding_mode='reflect',
groups=hidden_dim, bias=False),
get_norm_layer(hidden_dim, hidden_dim, n_groups, norm),
nn.GELU(),
SE(inp, hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, stride=1, padding=0, bias=False),
get_norm_layer(oup, oup, n_groups, norm),
)
        self.conv = PreNorm(inp, self.conv, norm, n_groups=n_groups)
def forward(self, x):
if self.downsample:
return self.proj(self.pool(x)) + self.conv(x)
else:
return x + self.conv(x)
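# Behaviour sketch (assumed shapes): without downsampling, MBConv acts as a
# residual inverted-bottleneck block; with downsample=True both the shortcut
# (max-pool + 1x1 projection) and the main branch halve the resolution:
#   blk = MBConv(64, 128, downsample=True, norm='batch')
#   y = blk(torch.randn(2, 64, 32, 32))     # y.shape == (2, 128, 16, 16)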
class Compact_Temporal_Aggregator(nn.Module):
def __init__(self, mode="mean"):
super(Compact_Temporal_Aggregator, self).__init__()
self.mode = mode
# moved dropout from ScaledDotProductAttention to here, applied after upsampling
self.attn_dropout = nn.Dropout(0.1) # no dropout via: nn.Dropout(0.0)
def forward(self, x, pad_mask=None, attn_mask=None):
if pad_mask is not None and pad_mask.any():
if self.mode == "att_group":
n_heads, b, t, h, w = attn_mask.shape
attn = attn_mask.view(n_heads * b, t, h, w)
if x.shape[-2] > w:
attn = nn.Upsample(
size=x.shape[-2:], mode="bilinear", align_corners=False
)(attn)
# this got moved out of ScaledDotProductAttention, apply after upsampling
attn = self.attn_dropout(attn)
else:
attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn)
attn = attn.view(n_heads, b, t, *x.shape[-2:])
attn = attn * (~pad_mask).float()[None, :, :, None, None]
out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW
out = attn[:, :, :, None, :, :] * out
out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW
out = torch.cat([group for group in out], dim=1) # -> BxCxHxW
return out
elif self.mode == "att_mean":
attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW
attn = nn.Upsample(
size=x.shape[-2:], mode="bilinear", align_corners=False
)(attn)
# this got moved out of ScaledDotProductAttention, apply after upsampling
attn = self.attn_dropout(attn)
attn = attn * (~pad_mask).float()[:, :, None, None]
out = (x * attn[:, :, None, :, :]).sum(dim=1)
return out
elif self.mode == "mean":
out = x * (~pad_mask).float()[:, :, None, None, None]
out = out.sum(dim=1) / (~pad_mask).sum(dim=1)[:, None, None, None]
return out
else:
if self.mode == "att_group":
n_heads, b, t, h, w = attn_mask.shape
attn = attn_mask.view(n_heads * b, t, h, w)
if x.shape[-2] > w:
attn = nn.Upsample(
size=x.shape[-2:], mode="bilinear", align_corners=False
)(attn)
# this got moved out of ScaledDotProductAttention, apply after upsampling
attn = self.attn_dropout(attn)
else:
attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn)
attn = attn.view(n_heads, b, t, *x.shape[-2:])
out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW
out = attn[:, :, :, None, :, :] * out
out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW
out = torch.cat([group for group in out], dim=1) # -> BxCxHxW
return out
elif self.mode == "att_mean":
attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW
attn = nn.Upsample(
size=x.shape[-2:], mode="bilinear", align_corners=False
)(attn)
# this got moved out of ScaledDotProductAttention, apply after upsampling
attn = self.attn_dropout(attn)
out = (x * attn[:, :, None, :, :]).sum(dim=1)
return out
elif self.mode == "mean":
return x.mean(dim=1)
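# Expected shapes (following the comments above): x is [B x T x C x H x W],
# attn_mask is [heads x B x T x h x w] holding low-resolution attention from
# the LTAE, and pad_mask is a boolean [B x T] marking temporally padded dates;
# every mode reduces the temporal dimension and returns [B x C x H x W].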
def get_nonlinearity(mode, eps):
    # map an identifier to the output nonlinearity used on the network heads
    if mode == 'relu':
        fct = nn.ReLU()
    elif mode == 'softplus':
        fct = lambda vars: nn.Softplus(beta=1, threshold=20)(vars) + eps
    elif mode == 'elu':
        fct = lambda vars: nn.ELU()(vars) + 1 + eps
    else:
        fct = nn.Identity()
    return fct
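# Minimal sketch: for 'softplus' and 'elu' the returned callable keeps the
# variance predictions strictly positive, e.g.
#   var_act = get_nonlinearity('softplus', eps=1e-3)
#   v = var_act(torch.randn(2, 13, 8, 8))   # all entries of v are > 0

# The original UnCRtainTS model below is kept commented out for reference;
# the active UNCRTAINTS class defined further down replaces it with a
# NAFNet-style encoder-decoder.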
# class UNCRTAINTS(nn.Module):
# def __init__(
# self,
# input_dim,
# encoder_widths=[128],
# decoder_widths=[128,128,128,128,128],
# out_conv=[S2_BANDS],
# out_nonlin_mean=False,
# out_nonlin_var='relu',
# agg_mode="att_group",
# encoder_norm="group",
# decoder_norm="batch",
# n_head=16,
# d_model=256,
# d_k=4,
# pad_value=0,
# padding_mode="reflect",
# positional_encoding=True,
# covmode='diag',
# scale_by=1,
# separate_out=False,
# use_v=False,
# block_type='mbconv',
# is_mono=False
# ):
# """
# UnCRtainTS architecture for spatio-temporal encoding of satellite image time series.
# Args:
# input_dim (int): Number of channels in the input images.
# encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder.
# This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1)
# in the architecture.
# The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution.
# decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of
# channels should be given is also from top to bottom. If this argument is not specified the decoder
# will have the same configuration as the encoder.
# out_conv (List[int]): Number of channels of the successive convolutions producing the final output.
# agg_mode (str): Aggregation mode for the skip connections. Can either be:
# - att_group (default) : Attention weighted temporal average, using the same
# channel grouping strategy as in the LTAE. The attention masks are bilinearly
# resampled to the resolution of the skipped feature maps.
# - att_mean : Attention weighted temporal average,
# using the average attention scores across heads for each date.
# - mean : Temporal average excluding padded dates.
# encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be:
# - group : GroupNorm (default)
# - batch : BatchNorm
# - instance : InstanceNorm
# - none: apply no normalization
# decoder_norm (str): similar to encoder_norm
# n_head (int): Number of heads in LTAE.
# d_model (int): Parameter of LTAE
# d_k (int): Key-Query space dimension
# pad_value (float): Value used by the dataloader for temporal padding.
# padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2d).
# positional_encoding (bool): If False, no positional encoding is used (default True).
# """
# super(UNCRTAINTS, self).__init__()
# self.n_stages = len(encoder_widths)
# self.encoder_widths = encoder_widths
# self.decoder_widths = decoder_widths
# self.out_widths = out_conv
# self.is_mono = is_mono
# self.use_v = use_v
# self.block_type = block_type
# self.enc_dim = decoder_widths[0] if decoder_widths is not None else encoder_widths[0]
# self.stack_dim = sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths)
# self.pad_value = pad_value
# self.padding_mode = padding_mode
# self.scale_by = scale_by
# self.separate_out = separate_out # define two separate layer streams for mean and variance predictions
# if decoder_widths is not None:
# assert encoder_widths[-1] == decoder_widths[-1]
# else: decoder_widths = encoder_widths
# # ENCODER
# self.in_conv = ConvBlock(
# nkernels=[input_dim] + [encoder_widths[0]],
# k=1, s=1, p=0,
# norm=encoder_norm,
# )
# if self.block_type=='mbconv':
# self.in_block = nn.ModuleList([MBConv(layer, layer, downsample=False, expansion=2, norm=encoder_norm) for layer in encoder_widths])
# elif self.block_type=='residual':
# self.in_block = nn.ModuleList([ResidualConvBlock(nkernels=[layer]+[layer], k=3, s=1, p=1, norm=encoder_norm, n_groups=4) for layer in encoder_widths])
# else: raise NotImplementedError
# if not self.is_mono:
# # LTAE
# if self.use_v:
# # same as standard LTAE, except we don't apply dropout on the low-resolution attention masks
# self.temporal_encoder = LTAE2d(
# in_channels=encoder_widths[0],
# d_model=d_model,
# n_head=n_head,
# mlp=[d_model, encoder_widths[0]], # MLP to map v, only used if self.use_v=True
# return_att=True,
# d_k=d_k,
# positional_encoding=positional_encoding,
# use_dropout=False
# )
# # linearly combine mask-weighted
# v_dim = encoder_widths[0]
# self.include_v = nn.Conv2d(encoder_widths[0]+v_dim, encoder_widths[0], 1)
# else:
# self.temporal_encoder = LTAE2dtiny(
# in_channels=encoder_widths[0],
# d_model=d_model,
# n_head=n_head,
# d_k=d_k,
# positional_encoding=positional_encoding,
# )
# self.temporal_aggregator = Compact_Temporal_Aggregator(mode=agg_mode)
# if self.block_type=='mbconv':
# self.out_block = nn.ModuleList([MBConv(layer, layer, downsample=False, expansion=2, norm=decoder_norm) for layer in decoder_widths])
# elif self.block_type=='residual':
# self.out_block = nn.ModuleList([ResidualConvBlock(nkernels=[layer]+[layer], k=3, s=1, p=1, norm=decoder_norm, n_groups=4) for layer in decoder_widths])
# else: raise NotImplementedError
# self.covmode = covmode
# if covmode=='uni':
# # batching across channel dimension
# covar_dim = S2_BANDS
# elif covmode=='iso':
# covar_dim = 1
# elif covmode=='diag':
# covar_dim = S2_BANDS
# else: covar_dim = 0
# self.mean_idx = S2_BANDS
# self.vars_idx = self.mean_idx + covar_dim
# # note: not including normalization layer and ReLU nonlinearity into the final ConvBlock
# # if inserting >1 layers into out_conv then consider treating normalizations separately
# self.out_dims = out_conv[-1]
# eps = 1e-9 if self.scale_by==1.0 else 1e-3
# if self.separate_out: # define two separate layer streams for mean and variance predictions
# self.out_conv_mean_1 = ConvBlock(nkernels=[decoder_widths[0]] + [S2_BANDS], k=1, s=1, p=0, norm='none', last_relu=False)
# if self.out_dims - self.mean_idx > 0:
# self.out_conv_var_1 = ConvBlock(nkernels=[decoder_widths[0]] + [self.out_dims - S2_BANDS], k=1, s=1, p=0, norm='none', last_relu=False)
# else:
# self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, norm='none', last_relu=False)
# # set output nonlinearities
# if out_nonlin_mean: self.out_mean = lambda vars: self.scale_by * nn.Sigmoid()(vars) # this is for predicting mean values in [0, 1]
# else: self.out_mean = nn.Identity() # just keep the mean estimates, without applying a nonlinearity
# if self.covmode in ['uni', 'iso', 'diag']:
# self.diag_var = get_nonlinearity(out_nonlin_var, eps)
# def forward(self, input, batch_positions=None):
# print(input.shape)
# pad_mask = (
# (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1)
# ) # BxT pad mask
# # SPATIAL ENCODER
# # collect feature maps in list 'feature_maps'
# out = self.in_conv.smart_forward(input)
# for layer in self.in_block:
# out = layer.smart_forward(out)
# if not self.is_mono:
# att_down = 32
# down = nn.AdaptiveMaxPool2d((att_down, att_down))(out.view(out.shape[0] * out.shape[1], *out.shape[2:])).view(out.shape[0], out.shape[1], out.shape[2], att_down, att_down)
# # TEMPORAL ENCODER
# if self.use_v:
# v, att = self.temporal_encoder(down, batch_positions=batch_positions, pad_mask=pad_mask)
# else:
# att = self.temporal_encoder(down, batch_positions=batch_positions, pad_mask=pad_mask)
# out = self.temporal_aggregator(out, pad_mask=pad_mask, attn_mask=att)
# if self.use_v:
# # upsample values to input resolution, then linearly combine with attention masks
# up_v = nn.Upsample(size=(out.shape[-2:]), mode="bilinear", align_corners=False)(v)
# out = self.include_v(torch.cat((out, up_v), dim=1))
# else: out = out.squeeze(dim=1)
# # SPATIAL DECODER
# for layer in self.out_block:
# out = layer.smart_forward(out)
# if self.separate_out:
# out_mean_1 = self.out_conv_mean_1(out)
# if self.out_dims - self.mean_idx > 0:
# out_var_1 = self.out_conv_var_1(out)
# out = torch.cat((out_mean_1, out_var_1), dim=1)
# else: out = out_mean_1 #out = out_mean_2
# else:
# out = self.out_conv(out) # predict mean and var in single layer
# # append a singleton temporal dimension such that outputs are [B x T=1 x C x H x W]
# out = out.unsqueeze(dim=1)
# # apply output nonlinearities
# # get mean predictions
# out_loc = self.out_mean(out[:,:,:self.mean_idx,...]) # mean predictions in [0,1]
# if not self.covmode: return out_loc
# out_cov = self.diag_var(out[:,:,self.mean_idx:self.vars_idx,...]) # var predictions > 0
# out = torch.cat((out_loc, out_cov), dim=2) # stack mean and var predictions plus cloud masks
# print(f"{out.shape}")
# return out
# additional imports for the NAFNet-style implementation below
# (torch and torch.nn are already imported at the top of the file)
import torch.nn.functional as F
import math
from abc import abstractmethod
class EmbedBlock(nn.Module):
"""
Any module where forward() takes embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` embeddings.
"""
class EmbedSequential(nn.Sequential, EmbedBlock):
"""
A sequential module that passes embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb):
for layer in self:
if isinstance(layer, EmbedBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
def gamma_embedding(gammas, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param gammas: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0,
end=half, dtype=torch.float32) / half
).to(device=gammas.device)
args = gammas[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
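def _demo_gamma_embedding():
    # minimal usage sketch (illustrative only; gamma_embedding is not called
    # elsewhere in this file): embed a batch of three noise levels
    gammas = torch.tensor([0.0, 0.5, 1.0])
    emb = gamma_embedding(gammas, dim=128)
    assert emb.shape == (3, 128)
    return emb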
class LayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, eps):
ctx.eps = eps
N, C, H, W = x.size()
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
ctx.save_for_backward(y, var, weight)
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_tensors  # ctx.saved_variables is deprecated
g = grad_output * weight.view(1, C, 1, 1)
mean_g = g.mean(dim=1, keepdim=True)
mean_gy = (g * y).mean(dim=1, keepdim=True)
gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
dim=0), None
class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super(LayerNorm2d, self).__init__()
self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
self.eps = eps
def forward(self, x):
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
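# Note: unlike nn.BatchNorm2d, this normalizes across the channel dimension
# independently at every spatial location (with a hand-written backward pass),
# as in NAFNet-style blocks. Quick shape check (assumed usage):
#   ln = LayerNorm2d(32)
#   y = ln(torch.randn(2, 32, 16, 16))      # y.shape == (2, 32, 16, 16)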
class SimpleGate(nn.Module):
def forward(self, x):
x1, x2 = x.chunk(2, dim=1)
return x1 * x2
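# SimpleGate halves the channel count: an input of shape [B x 2k x H x W] is
# split into two [B x k x H x W] halves that are multiplied elementwise,
# standing in for a conventional activation function.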
class CondNAFBlock(nn.Module):
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
super().__init__()
dw_channel = c * DW_Expand
self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
bias=True)
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
# Simplified Channel Attention
# self.sca = nn.Sequential(
# nn.AdaptiveAvgPool2d(1),
# nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
# groups=1, bias=True),
# )
self.sca_avg = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1,
groups=1, bias=True),
)
self.sca_max = nn.Sequential(
nn.AdaptiveMaxPool2d(1),
nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1,
groups=1, bias=True),
)
# SimpleGate
self.sg = SimpleGate()
ffn_channel = FFN_Expand * c
self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.norm1 = LayerNorm2d(c)
self.norm2 = LayerNorm2d(c)
self.dropout1 = nn.Dropout(
drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.dropout2 = nn.Dropout(
drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros(
(1, c, 1, 1)), requires_grad=True)
def forward(self, inp):
x = inp
x = self.norm1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.sg(x)
x_avg, x_max = x.chunk(2, dim=1)
x_avg = self.sca_avg(x_avg)*x_avg
x_max = self.sca_max(x_max)*x_max
x = torch.cat([x_avg, x_max], dim=1)
x = self.conv3(x)
x = self.dropout1(x)
y = inp + x * self.beta
x = self.conv4(self.norm2(y))
x = self.sg(x)
x = self.conv5(x)
x = self.dropout2(x)
return y + x * self.gamma
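# NAFBlock below is currently structurally identical to CondNAFBlock; the two
# are presumably kept separate so that the conditional-encoder blocks can
# diverge later (see e.g. the commented-out time embedding in NAFBlock).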
class NAFBlock(nn.Module):
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
super().__init__()
dw_channel = c * DW_Expand
self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
bias=True)
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
# Simplified Channel Attention
# self.sca = nn.Sequential(
# nn.AdaptiveAvgPool2d(1),
# nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
# groups=1, bias=True),
# )
self.sca_avg = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1,
groups=1, bias=True),
)
self.sca_max = nn.Sequential(
nn.AdaptiveMaxPool2d(1),
nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1,
groups=1, bias=True),
)
# SimpleGate
self.sg = SimpleGate()
ffn_channel = FFN_Expand * c
self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c,
kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.norm1 = LayerNorm2d(c)
self.norm2 = LayerNorm2d(c)
self.dropout1 = nn.Dropout(
drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.dropout2 = nn.Dropout(
drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros(
(1, c, 1, 1)), requires_grad=True)
# self.time_emb = nn.Sequential(
# nn.SiLU(),
# nn.Linear(256, c),
# )
def forward(self, inp):
x = inp
x = self.norm1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.sg(x)
x_avg, x_max = x.chunk(2, dim=1)
x_avg = self.sca_avg(x_avg)*x_avg
x_max = self.sca_max(x_max)*x_max
x = torch.cat([x_avg, x_max], dim=1)
x = self.conv3(x)
x = self.dropout1(x)
y = inp + x * self.beta
# y = y+self.time_emb(t)[..., None, None]
x = self.conv4(self.norm2(y))
x = self.sg(x)
x = self.conv5(x)
x = self.dropout2(x)
return y + x * self.gamma
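# Shape sketch (assumed usage): NAFBlock preserves channels and resolution:
#   blk = NAFBlock(c=64)
#   y = blk(torch.randn(2, 64, 32, 32))     # y.shape == (2, 64, 32, 32)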
class UNCRTAINTS(nn.Module):
def __init__(
self,
input_dim=15,
out_conv=[13],
width=64,
middle_blk_num=1,
enc_blk_nums=[1, 1, 1, 1],
dec_blk_nums=[1, 1, 1, 1],
encoder_widths=[128],
decoder_widths=[128,128,128,128,128],
out_nonlin_mean=False,
out_nonlin_var='relu',
agg_mode="att_group",
encoder_norm="group",
decoder_norm="batch",
n_head=16,
d_model=256,
d_k=4,
pad_value=0,
padding_mode="reflect",
positional_encoding=True,
covmode='diag',
scale_by=1,
separate_out=False,
use_v=False,
block_type='mbconv',
is_mono=False
):
super().__init__()
self.intro = nn.Conv2d(in_channels=input_dim, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
bias=True)
# self.cond_intro = nn.Conv2d(in_channels=img_channel+2, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
# bias=True)
self.ending = nn.Conv2d(in_channels=width, out_channels=out_conv[0], kernel_size=3, padding=1, stride=1, groups=1,
bias=True)
# self.inp_ending = nn.Conv2d(in_channels=img_channel, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
# bias=True)
self.encoders = nn.ModuleList()
self.cond_encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.middle_blks = nn.ModuleList()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.cond_downs = nn.ModuleList()
chan = width
for num in enc_blk_nums:
self.encoders.append(
nn.Sequential(
*[NAFBlock(chan) for _ in range(num)]
)
)
self.cond_encoders.append(
nn.Sequential(
*[CondNAFBlock(chan) for _ in range(num)]
)
)
self.downs.append(
nn.Conv2d(chan, 2*chan, 2, 2)
)
# self.cond_downs.append(
# nn.Conv2d(chan, 2*chan, 2, 2)
# )
chan = chan * 2
self.middle_blks = \
nn.Sequential(
*[NAFBlock(chan) for _ in range(middle_blk_num)]
)
for num in dec_blk_nums:
self.ups.append(
nn.Sequential(
nn.Conv2d(chan, chan * 2, 1, bias=False),
nn.PixelShuffle(2)
)
)
chan = chan // 2
self.decoders.append(
nn.Sequential(
*[NAFBlock(chan) for _ in range(num)]
)
)
self.padder_size = 2 ** len(self.encoders)
# self.map = nn.Sequential(
# nn.Linear(64, 256),
# nn.SiLU(),
# nn.Linear(256, 256),
# )
    def forward(self, inp, batch_positions=None):
        # batch_positions is unused here; kept for interface compatibility
        # with the legacy UNCRTAINTS model (the conditional-encoder path is
        # likewise built in __init__ but not used in this forward pass)
        # inp = self.check_image_size(inp)
        inp = inp.squeeze(1)  # [B x T=1 x C x H x W] -> [B x C x H x W]
x = self.intro(inp)
encs = []
for encoder, down in zip(self.encoders, self.downs):
x = encoder(x)
# b, c, h, w = cond.shape
# tmp_cond = cond.view(b//3, 3, c, h, w).sum(dim=1)
# tmp_cond = cond
# x = x + tmp_cond
encs.append(x)
x = down(x)
# cond = cond_down(cond)
x = self.middle_blks(x)
for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
x = up(x)
x = x + enc_skip
x = decoder(x)
x = self.ending(x)
# x = x + self.inp_ending(inp)
# print(x.shape)
return x.unsqueeze(1)
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.padder_size - h %
self.padder_size) % self.padder_size
mod_pad_w = (self.padder_size - w %
self.padder_size) % self.padder_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
return x
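# Note: check_image_size pads H and W up to a multiple of
# 2 ** len(self.encoders) (16 with the default four encoder stages) so that
# the stride-2 downsamplings divide evenly; its call in forward() is commented
# out, so inputs are assumed to already have compatible sizes.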
if __name__ == '__main__':
    # sanity check: the network preserves the ground/spatial resolution
    # inputs follow the [B x T=1 x C x H x W] convention expected by forward()
    inp = torch.randn(1, 1, 15, 256, 256)
    net = UNCRTAINTS(
        input_dim=15,
        out_conv=[13],
        width=64,
        middle_blk_num=1,
        enc_blk_nums=[1, 1, 1, 1],
        dec_blk_nums=[1, 1, 1, 1],
    )
    out = net(inp)
    assert out.shape == (1, 1, 13, 256, 256)
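    # quick capacity check (sketch): count the trainable parameters
    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"trainable parameters: {n_params / 1e6:.2f} M")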
# from thop import profile
# out_shape = (1, 12, 384, 384)
# input_shape = (1, 13, 384, 384)
# model = DiffCR(
# img_channel=13,
# width=32,
# middle_blk_num=1,
# enc_blk_nums=[1, 1, 1, 1],
# dec_blk_nums=[1, 1, 1, 1],
# )
# # Use thop's profile function to obtain FLOPs and the parameter count
# flops, params = profile(model, inputs=(torch.randn(out_shape), torch.ones(1,), torch.randn(input_shape)))
# print(f"FLOPs: {flops / 1e9} G")
# print(f"Parameters: {params / 1e6} M")