import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class GlobalContextBlock(nn.Module):
    """Pool a global context vector per frame and fuse it back into the input.

    For every frame, a 1x1 convolution produces one attention logit per
    spatial position. The logits are soft-maxed over all positions and used
    as weights to average the input features into a single ``(C, 1, 1)``
    context vector. That vector goes through a small bottleneck
    (conv -> LayerNorm -> LeakyReLU -> conv) and the result is broadcast
    over the spatial dimensions and fused with the input:

    * ``"mul"``: the bottleneck output is passed through a sigmoid and
      multiplied with the input (channel-wise gate).
    * ``"add"``: the bottleneck output is added to the input.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the bottleneck. The
            fused term has shape ``(B, out_channels, T, 1, 1)`` and is
            combined element-wise with the input, so it must broadcast
            against the input's channel dimension.
        min_channels: Lower bound for the bottleneck width; the width is
            ``max(min_channels, out_channels // 2)``.
        init_bias: Initial bias of the last convolution. Its weights are
            zero-initialised, so at initialisation the last convolution
            outputs ``init_bias`` everywhere; in ``"mul"`` mode the gate
            therefore starts at ``sigmoid(init_bias)``.
        fusion_type: Either ``"mul"`` or ``"add"``.

    Raises:
        ValueError: If ``fusion_type`` is not ``"mul"`` or ``"add"``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        min_channels: int = 16,
        init_bias: float = -10.,
        fusion_type: str = "mul",
    ):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if fusion_type not in ("mul", "add"):
            raise ValueError(f"Unsupported fusion type: {fusion_type}")
        self.fusion_type = fusion_type

        # One attention logit per spatial position.
        self.conv_ctx = nn.Conv2d(in_channels, 1, kernel_size=1)

        num_channels = max(min_channels, out_channels // 2)
        # Both fusion modes share the same bottleneck; "mul" only appends a
        # sigmoid. Layer order is kept so state_dict keys stay unchanged.
        layers = [
            nn.Conv2d(in_channels, num_channels, kernel_size=1),
            nn.LayerNorm([num_channels, 1, 1]),  # TODO: LayerNorm or GroupNorm?
            nn.LeakyReLU(0.1),
            nn.Conv2d(num_channels, out_channels, kernel_size=1),
        ]

        # Zero weights + constant bias: the branch initially ignores its
        # input and emits ``init_bias`` for every channel.
        # NOTE(review): in "add" mode the default init_bias of -10 shifts
        # every element of the input by -10 at initialisation; confirm that
        # this is intended for that mode.
        last_conv = layers[-1]
        nn.init.zeros_(last_conv.weight)
        nn.init.constant_(last_conv.bias, init_bias)

        if fusion_type == "mul":
            layers.append(nn.Sigmoid())
            self.conv_mul = nn.Sequential(*layers)
        else:
            self.conv_add = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the global-context fusion.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` or ``(B, C, T, H, W)``.

        Returns:
            The fused tensor, with the same number of dimensions as ``x``.

        Raises:
            ValueError: If ``x`` is neither 4- nor 5-dimensional.
        """
        if x.ndim not in (4, 5):
            raise ValueError(
                "Expected a 4D (B, C, H, W) or 5D (B, C, T, H, W) input, "
                f"got a {x.ndim}D tensor"
            )

        is_image = x.ndim == 4
        if is_image:
            # Treat an image as a single-frame video.
            x = rearrange(x, "b c h w -> b c 1 h w")

        # x: (B, C, T, H, W)
        orig_x = x
        batch_size = x.shape[0]
        # Fold time into the batch so every frame is pooled independently.
        x = rearrange(x, "b c t h w -> (b t) c h w")

        # Spatial attention weights: (B*T, 1, H*W), summing to 1 over H*W.
        ctx = self.conv_ctx(x)
        ctx = rearrange(ctx, "b c h w -> b c (h w)")
        ctx = F.softmax(ctx, dim=-1)

        # Attention-weighted average of the features: (B*T, C, 1).
        # torch.einsum only accepts single-letter subscripts, so the
        # attention channel is labelled ``k`` and the feature channel ``c``.
        flattened_x = rearrange(x, "b c h w -> b c (h w)")
        x = torch.einsum("b k n, b c n -> b c k", ctx, flattened_x)
        # Append a unit axis -> (B*T, C, 1, 1) so the 1x1 convs can run.
        x = rearrange(x, "... -> ... 1")

        if self.fusion_type == "mul":
            mul_term = self.conv_mul(x)
            mul_term = rearrange(mul_term, "(b t) c h w -> b c t h w", b=batch_size)
            x = orig_x * mul_term
        else:
            add_term = self.conv_add(x)
            add_term = rearrange(add_term, "(b t) c h w -> b c t h w", b=batch_size)
            x = orig_x + add_term

        if is_image:
            x = rearrange(x, "b c 1 h w -> b c h w")

        return x