|
import torch
|
|
import torch.nn as nn
|
|
import einops
|
|
|
|
from torch.nn import functional as F
|
|
from torch.jit import Final
|
|
from timm.layers import use_fused_attn
|
|
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over a 3-D ``(B, N, C)`` input.

    A single linear layer produces queries, keys and values, which are split
    into ``num_heads`` heads of size ``dim // num_heads``. Attention is then
    evaluated either with ``F.scaled_dot_product_attention`` or with an
    explicit softmax(QK^T)V, depending on what ``use_fused_attn()`` returned
    when the module was constructed.

    Args:
        dim: Size of the last input dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the joint q/k/v projection has a bias term.
        qk_norm: If true, apply ``norm_layer(head_dim)`` to the per-head
            queries and keys; otherwise they pass through ``nn.Identity``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        norm_layer: Callable building the q/k normalisation layer from the
            head dimension.
    """

    # Marked Final so TorchScript can treat the branch selector as a constant.
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Submodules are created in a fixed order so that seeded weight
        # initialisation and state-dict layout stay stable.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(head_dim)
            self.k_norm = norm_layer(head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` and return a tensor of the same shape."""
        batch, tokens, channels = x.shape

        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        query = self.q_norm(query)
        key = self.k_norm(key)

        if not self.fused_attn:
            # Explicit path: scale the queries, softmax over the key axis,
            # drop attention weights, then mix the values.
            scores = (query * self.scale) @ key.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            x = weights @ value
        else:
            # Fused kernel; dropout is only active while training.
            x = F.scaled_dot_product_attention(
                query, key, value,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C).
        merged = x.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
|
|
|
|
|
class SVDNoiseUnet(nn.Module):
    """Re-predict the singular values of a tiled 4-channel input.

    ``forward`` tiles the four input channels into one 2x2 grid per sample,
    takes its SVD, derives new singular values from the singular vectors and
    the original singular values, and rebuilds a tensor in the input layout
    from the original singular vectors and the new values.

    Args:
        in_channels: Used only to size the input features of the linear
            layers (``resolution * in_channels // 2``).
        out_channels: Used only to size the output features of the linear
            layers (``resolution * out_channels // 2``).
        resolution: Spatial-size factor for the linear layers.
            NOTE(review): presumably the height/width of the input map;
            confirm against callers.

    Note:
        ``self.bn`` is constructed but never applied in ``forward``. It is
        kept so the module's parameters and state-dict keys are unchanged.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=128):
        super().__init__()

        in_features = int(resolution * in_channels // 2)
        out_features = int(resolution * out_channels // 2)

        def bottleneck():
            # Two-layer MLP with a 64-wide hidden layer; used for the left and
            # right singular vectors.
            return nn.Sequential(
                nn.Linear(in_features, 64),
                nn.ReLU(inplace=True),
                nn.Linear(64, out_features),
            )

        # Creation order is kept (mlp1, mlp2, mlp3, attention, bn, mlp4) so
        # seeded initialisation and state-dict ordering are unaffected.
        self.mlp1 = bottleneck()
        self.mlp2 = bottleneck()

        self.mlp3 = nn.Sequential(
            nn.Linear(in_features, out_features),
        )

        self.attention = Attention(out_features)

        self.bn = nn.BatchNorm2d(out_features)

        self.mlp4 = nn.Sequential(
            nn.Linear(out_features, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, out_features),
        )

    def forward(self, x, residual=False):
        """Return a tensor in the input layout with re-predicted singular values.

        Args:
            x: 4-D tensor whose channel axis is split as ``(a c)`` with
                ``a=2, c=2``, i.e. it must have four channels.
                NOTE(review): the arithmetic below appears to require the
                tiled matrix to be square with side equal to the layers'
                feature size; confirm with callers.
            residual: Accepted but not read by this method.
        """
        # Unpacking is kept as an implicit check that ``x`` is 4-D.
        batch, channels, height, width = x.shape

        # Lay the four channels out as a 2x2 grid of (h, w) tiles.
        tiled = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2)

        # torch.linalg.svd returns (U, S, Vh); the third factor is already
        # transposed, which is why it is multiplied directly when rebuilding.
        left, sigma, right_h = torch.linalg.svd(tiled)

        left_feat = self.mlp1(left.transpose(1, 2))
        right_feat = self.mlp2(right_h)
        sigma_feat = self.mlp3(sigma).unsqueeze(1)
        mixed = left_feat + right_feat + sigma_feat

        # Attend across rows, average over them, then predict a correction
        # that is added to the original singular values.
        pooled = self.attention(mixed).mean(dim=1)
        new_sigma = self.mlp4(pooled) + sigma

        rebuilt = torch.matmul(torch.matmul(left, torch.diag_embed(new_sigma)), right_h)
        return einops.rearrange(rebuilt, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
|
|
|
|
|
class SVDNoiseUnet_Concise(nn.Module):
|
|
def __init__(self, in_channels=4, out_channels=4, resolution=128):
|
|
super(SVDNoiseUnet_Concise, self).__init__()
|
|
|
|
|