import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class GlobalContextBlock(nn.Module):
    """Global-context-style block: pools spatial information into a single
    per-frame context vector via attention pooling, transforms it with a
    1x1-conv bottleneck, and fuses it back into the input either by
    channel-wise multiplication (gating) or by addition.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        min_channels: int = 16,
        init_bias: float = -10.0,
        fusion_type: str = "mul",
    ):
        super().__init__()
        assert fusion_type in ("mul", "add"), f"Unsupported fusion type: {fusion_type}"
        self.fusion_type = fusion_type
        # 1x1 conv producing a single-channel attention logit per spatial position.
        self.conv_ctx = nn.Conv2d(in_channels, 1, kernel_size=1)
        # Bottleneck width, floored at min_channels.
        num_channels = max(min_channels, out_channels // 2)
        if fusion_type == "mul":
            self.conv_mul = nn.Sequential(
                nn.Conv2d(in_channels, num_channels, kernel_size=1),
                nn.LayerNorm([num_channels, 1, 1]),  # TODO: LayerNorm or GroupNorm?
                nn.LeakyReLU(0.1),
                nn.Conv2d(num_channels, out_channels, kernel_size=1),
                nn.Sigmoid(),
            )
            # Zero-init the final conv so the gate starts at the constant
            # sigmoid(init_bias), independent of the input.
            nn.init.zeros_(self.conv_mul[-2].weight)
            nn.init.constant_(self.conv_mul[-2].bias, init_bias)
        else:
            self.conv_add = nn.Sequential(
                nn.Conv2d(in_channels, num_channels, kernel_size=1),
                nn.LayerNorm([num_channels, 1, 1]),  # TODO: LayerNorm or GroupNorm?
                nn.LeakyReLU(0.1),
                nn.Conv2d(num_channels, out_channels, kernel_size=1),
            )
            # Zero-init so the additive term starts at the constant init_bias.
            nn.init.zeros_(self.conv_add[-1].weight)
            nn.init.constant_(self.conv_add[-1].bias, init_bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept both images (B, C, H, W) and videos (B, C, T, H, W).
        is_image = x.ndim == 4
        if is_image:
            x = rearrange(x, "b c h w -> b c 1 h w")
        # x: (B, C, T, H, W)
        orig_x = x
        batch_size = x.shape[0]
        # Fold time into the batch so the 2D convs apply per frame.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        # Attention pooling: softmax over all spatial positions...
        ctx = self.conv_ctx(x)
        ctx = rearrange(ctx, "b c h w -> b c (h w)")
        ctx = F.softmax(ctx, dim=-1)
        # ...then a weighted sum of features yields one context vector per frame.
        flattened_x = rearrange(x, "b c h w -> b c (h w)")
        x = torch.einsum("b c1 n, b c2 n -> b c2 c1", ctx, flattened_x)  # (B*T, C, 1)
        x = rearrange(x, "... -> ... 1")  # (B*T, C, 1, 1)
        if self.fusion_type == "mul":
            mul_term = self.conv_mul(x)
            mul_term = rearrange(mul_term, "(b t) c h w -> b c t h w", b=batch_size)
            x = orig_x * mul_term
        else:
            add_term = self.conv_add(x)
            add_term = rearrange(add_term, "(b t) c h w -> b c t h w", b=batch_size)
            x = orig_x + add_term
        if is_image:
            x = rearrange(x, "b c 1 h w -> b c h w")
        return x
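

# A minimal usage sketch, not part of the original snippet: the fused term
# broadcasts against the input over (T, H, W), so the block is
# shape-preserving when out_channels == in_channels. The channel count,
# batch size, and spatial/temporal sizes below are illustrative assumptions.
if __name__ == "__main__":
    block = GlobalContextBlock(in_channels=64, out_channels=64, fusion_type="mul")
    image = torch.randn(2, 64, 32, 32)      # (B, C, H, W)
    video = torch.randn(2, 64, 8, 32, 32)   # (B, C, T, H, W)
    assert block(image).shape == image.shape
    assert block(video).shape == video.shape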