"""
ChAda-ViT (i.e. Channel Adaptive ViT) is a variant of ViT that can handle multi-channel images.
"""
import logging
import math
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from torch import Tensor
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from .config_chada_vit import ChAdaViTConfig
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""Copy & paste from PyTorch official master until it's in a few official releases - RW
Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
"""
def norm_cdf(x):
"""Computes standard normal cumulative distribution function"""
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
        logging.warning(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
"""Copy & paste from PyTorch official master until it's in a few official releases - RW
Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class TransformerEncoderLayer(Module):
r"""
Mostly copied from torch.nn.TransformerEncoderLayer, but with the following changes:
- Added the possibility to retrieve the attention weights
"""
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
embed_dim=d_model,
num_heads=nhead,
dropout=dropout,
batch_first=batch_first,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
# We can't test self.activation in forward() in TorchScript,
# so stash some information about it instead.
if activation is F.relu:
self.activation_relu_or_gelu = 1
elif activation is F.gelu:
self.activation_relu_or_gelu = 2
else:
self.activation_relu_or_gelu = 0
self.activation = activation
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_attention=False,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
x = src
if self.norm_first:
attn, attn_weights = self._sa_block(
x=self.norm1(x),
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
return_attention=return_attention,
)
if return_attention:
return attn_weights
x = x + attn
x = x + self._ff_block(self.norm2(x))
else:
attn, attn_weights = self._sa_block(
x=self.norm1(x),
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
return_attention=return_attention,
)
if return_attention:
return attn_weights
x = self.norm1(x + attn)
x = self.norm2(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
return_attention: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
x, attn_weights = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=return_attention,
average_attn_weights=False,
)
return self.dropout1(x), attn_weights
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
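# Usage note (illustrative): with the layer above, `layer(x, src_key_padding_mask=mask)`
# returns the transformed sequence, while passing `return_attention=True` short-circuits
# after the self-attention block and instead returns the per-head attention weights,
# of shape (batch, num_heads, seq_len, seq_len) for batched inputs.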
class TokenLearner(nn.Module):
    """Image-to-patch embedding: splits each (single-channel) image into non-overlapping
    patches and linearly projects each patch to an `embed_dim`-dimensional token."""
def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        x = self.proj(x)  # (N, embed_dim, H // patch_size, W // patch_size)
        x = x.flatten(2)  # (N, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (N, num_patches, embed_dim)
        return x
class ChAdaViTModel(PreTrainedModel):
"""Channel Adaptive Vision Transformer"""
config_class = ChAdaViTConfig
def __init__(self, config):
super().__init__(config)
        # Embedding dimension
        self.num_features = self.embed_dim = config.embed_dim
        # Maximum number of channels per image in the batch
        self.max_channels = config.max_number_channels
# Tokenization module
self.token_learner = TokenLearner(
img_size=config.img_size[0],
patch_size=config.patch_size,
in_chans=config.in_chans,
embed_dim=self.embed_dim,
)
num_patches = self.token_learner.num_patches
        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.embed_dim)
        )  # (1, 1, embed_dim), broadcast over the batch
        self.channel_token = nn.Parameter(
            torch.zeros(1, self.max_channels, 1, self.embed_dim)
        )  # (1, max_channels, 1, embed_dim), broadcast over batch and patch tokens
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1, num_patches + 1, self.embed_dim)
        )  # (1, 1, num_patches + 1, embed_dim), shared across channels
self.pos_drop = nn.Dropout(p=config.drop_rate)
# TransformerEncoder block
        # Per-layer dropout rates, linearly spaced from 0 to config.drop_path_rate
        # (applied as the dropout of each encoder layer)
        dpr = [
            x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
        ]
self.blocks = nn.ModuleList(
[
TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=config.num_heads,
dim_feedforward=2048,
dropout=dpr[i],
batch_first=True,
)
for i in range(config.depth)
]
)
self.norm = nn.LayerNorm(self.embed_dim)
# Classifier head
self.head = nn.Linear(self.embed_dim, config.num_classes) if config.num_classes > 0 else nn.Identity()
# Return only the [CLS] token or all tokens
self.return_all_tokens = config.return_all_tokens
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.cls_token, std=0.02)
trunc_normal_(self.channel_token, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
    def add_pos_encoding_per_channel(self, x, w, h, class_pos_embed: bool = False):
        """
        Returns the positional embeddings to add to EACH channel's patch tokens
        (or to the [CLS] token when `class_pos_embed` is True), interpolating them
        when the input resolution differs from the training resolution.
        """
npatch = x.shape[2]
N = self.pos_embed.shape[2] - 1
# --------------------- [CLS] positional encoding --------------------- #
if class_pos_embed:
return self.pos_embed[:, :, 0]
# --------------------- Patches positional encoding --------------------- #
# If the input size is the same as the training size, return the positional embeddings for the desired type
if npatch == N and w == h:
return self.pos_embed[:, :, 1:]
# Otherwise, interpolate the positional encoding for the input tokens
class_pos_embed = self.pos_embed[:, :, 0]
patch_pos_embed = self.pos_embed[:, :, 1:]
dim = x.shape[-1]
w0 = w // self.token_learner.patch_size
h0 = h // self.token_learner.patch_size
# a small number is added by DINO team to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode="bicubic",
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed.unsqueeze(0)
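    # Worked example for the interpolation branch above (illustrative numbers, assuming
    # img_size=224 and patch_size=16): pos_embed holds N = 196 patch positions on a
    # 14x14 grid. For a 448x448 input, w0 = h0 = 448 // 16 = 28, so the grid is
    # bicubically resized from 14x14 to 28x28 and 784 interpolated positional
    # embeddings are returned.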
def channel_aware_tokenization(self, x, index, list_num_channels, max_channels=10):
B, nc, w, h = x.shape # (B*num_channels, 1, w, h)
# Tokenize through linear embedding
tokens_per_channel = self.token_learner(x)
        # Group the per-channel token sequences by image (split along the flattened batch*channel dimension)
chunks = torch.split(tokens_per_channel, list_num_channels[index], dim=0)
# Pad the tokens tensor with zeros for each image separately in the chunks list
padded_tokens = [
torch.cat(
[
chunk,
torch.zeros(
(max_channels - chunk.size(0), chunk.size(1), chunk.size(2)),
device=chunk.device,
),
],
dim=0,
)
if chunk.size(0) < max_channels
else chunk
for chunk in chunks
]
# Stack along the batch dimension
padded_tokens = torch.stack(padded_tokens, dim=0)
num_tokens = padded_tokens.size(2)
# Reshape the patches embeddings on the channel dimension
padded_tokens = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))
# Compute the masking for avoiding self-attention on empty padded channels
channel_mask = torch.all(padded_tokens == 0.0, dim=-1)
# Destack to obtain the original number of channels
padded_tokens = padded_tokens.reshape(-1, max_channels, num_tokens, padded_tokens.size(-1))
        # Add the positional embeddings to the patch tokens of each channel
padded_tokens = padded_tokens + self.add_pos_encoding_per_channel(
padded_tokens, w, h, class_pos_embed=False
)
# Add the [CHANNEL] token to the embed patch tokens
if max_channels == self.max_channels:
channel_tokens = self.channel_token.expand(padded_tokens.shape[0], -1, padded_tokens.shape[2], -1)
padded_tokens = padded_tokens + channel_tokens
# Restack the patches embeddings on the channel dimension
embeddings = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))
# Expand the [CLS] token to the batch dimension
cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1)
# Add [POS] positional encoding to the [CLS] token
cls_tokens = cls_tokens + self.add_pos_encoding_per_channel(embeddings, w, h, class_pos_embed=True)
# Concatenate the [CLS] token to the embed patch tokens
embeddings = torch.cat([cls_tokens, embeddings], dim=1)
# Adding a False value to the beginning of each channel_mask to account for the [CLS] token
channel_mask = torch.cat(
[
torch.tensor([False], device=channel_mask.device).expand(channel_mask.size(0), 1),
channel_mask,
],
dim=1,
)
return self.pos_drop(embeddings), channel_mask
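    # Shapes returned by `channel_aware_tokenization` (for reference):
    #   embeddings:   (num_images, 1 + max_channels * num_tokens, embed_dim), i.e. one
    #                 [CLS] token followed by max_channels stacks of patch tokens
    #   channel_mask: (num_images, 1 + max_channels * num_tokens), boolean, True on tokens
    #                 coming from zero-padded channels (False on [CLS] and real tokens)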
    def forward(self, x, index, list_num_channels):
        # Apply the TokenLearner module and channel-aware padding to obtain the token sequence
        x, channel_mask = self.channel_aware_tokenization(
            x, index, list_num_channels
        )  # (B, 1 + max_channels * num_tokens, embed_dim)
# Apply the self-attention layers with masked self-attention
for blk in self.blocks:
x = blk(
x, src_key_padding_mask=channel_mask
) # Use src_key_padding_mask to mask out padded tokens
# Normalize
x = self.norm(x)
if self.return_all_tokens:
# Create a mask to select non-masked tokens (excluding CLS token)
non_masked_tokens_mask = ~channel_mask[:, 1:]
non_masked_tokens = x[:, 1:][non_masked_tokens_mask]
return non_masked_tokens # return non-masked tokens (excluding CLS token)
else:
return x[:, 0] # return only the [CLS] token
def channel_token_sanity_check(self, x):
"""
Helper function to check consistency of channel tokens.
"""
# 1. Compare Patches Across Different Channels
print("Values for the first patch across different channels:")
        for ch in range(self.max_channels):
print(f"Channel {ch + 1}:", x[0, ch, 0, :5]) # Print first 5 values of the embedding for brevity
print("\n")
# 2. Compare Patches Within the Same Channel
        for ch in range(self.max_channels):
is_same = torch.all(x[0, ch, 0] == x[0, ch, 1])
print(f"First and second patch embeddings are the same for Channel {ch + 1}: {is_same.item()}")
# 3. Check Consistency Across Batch
print("Checking consistency of channel tokens across the batch:")
        for ch in range(self.max_channels):
is_consistent = torch.all(x[0, ch, 0] == x[1, ch, 0])
print(
f"Channel token for first patch is consistent between first and second image for Channel {ch + 1}: {is_consistent.item()}"
)
def get_last_selfattention(self, x):
x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1)
for i, blk in enumerate(self.blocks):
if i < len(self.blocks) - 1:
x = blk(x, src_key_padding_mask=channel_mask)
else:
# return attention of the last block
return blk(x, src_key_padding_mask=channel_mask, return_attention=True)
    def get_intermediate_layers(self, x, n=1):
        # Tokenize as single-channel inputs, mirroring `get_last_selfattention`
        x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1)
# return the output tokens from the `n` last blocks
output = []
for i, blk in enumerate(self.blocks):
x = blk(x, src_key_padding_mask=channel_mask)
if len(self.blocks) - i <= n:
output.append(self.norm(x))
return output
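
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the model definition). It
    # assumes ChAdaViTConfig accepts the keyword arguments read in ChAdaViTModel.__init__
    # (embed_dim, max_number_channels, img_size, patch_size, in_chans, drop_rate,
    # drop_path_rate, depth, num_heads, num_classes, return_all_tokens); the actual
    # defaults live in config_chada_vit.py.
    config = ChAdaViTConfig(
        embed_dim=192,
        max_number_channels=10,
        img_size=[224],
        patch_size=16,
        in_chans=1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        depth=2,
        num_heads=3,
        num_classes=0,
        return_all_tokens=False,
    )
    model = ChAdaViTModel(config).eval()

    # Two images with different channel counts (3 and 1), flattened channel-wise to
    # (sum of channels, 1, H, W) as expected by `forward`.
    dummy_batch = torch.randn(3 + 1, 1, 224, 224)
    with torch.no_grad():
        cls_embeddings = model(dummy_batch, index=0, list_num_channels=[[3, 1]])
    print(cls_embeddings.shape)  # torch.Size([2, 192]): one [CLS] embedding per image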