# ktda/models/fmm.py
from typing import Callable, Optional

from mmengine.model import BaseModule
from mmseg.registry import MODELS
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block as TransformerBlock
from torch import Tensor, nn
from torch.nn import functional as F
class Mlp(nn.Module):
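    """Two-layer MLP with dropout, applied over the channel (last) dimension."""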
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
@MODELS.register_module()
class FMM(BaseModule):
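    """Per-level feature adapter.

    For every channel count in ``in_channels``, builds a stack of ``mlp_nums``
    adapter blocks: bottleneck MLPs with hidden width ``in_channel // rank_dim``
    when ``model_type == "mlp"``, or timm Vision Transformer blocks when
    ``model_type == "vitBlock"``. In ``forward``, each (B, C, H, W) feature map
    is flattened to a (B, H*W, C) token sequence, passed through its adapter,
    and reshaped back, so output shapes match the inputs.
    """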
    def __init__(
        self,
        in_channels,
        rank_dim=4,
        mlp_nums=1,
        model_type="mlp",
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        init_cfg=None,
    ):
        super().__init__(init_cfg)
        self.adapters = nn.ModuleList()
        if model_type == "mlp":
            # One adapter per feature level: a stack of mlp_nums bottleneck MLPs
            # whose hidden width is in_channel // rank_dim.
            for in_channel in in_channels:
                mlp_list = []
                for _ in range(mlp_nums):
                    mlp_list.append(
                        Mlp(
                            in_channel,
                            hidden_features=in_channel // rank_dim,
                            out_features=in_channel,
                        )
                    )
                mlp_model = nn.Sequential(*mlp_list)
                self.adapters.append(mlp_model)
        elif model_type == "vitBlock":
            # One adapter per feature level: a stack of mlp_nums timm ViT blocks
            # operating on the flattened (B, H*W, C) token sequence.
            for in_channel in in_channels:
                model_list = []
                for _ in range(mlp_nums):
                    model_list.append(
                        TransformerBlock(
                            in_channel,
                            num_heads=num_heads,
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_norm=qk_norm,
                            init_values=init_values,
                            proj_drop=proj_drop_rate,
                            attn_drop=attn_drop_rate,
                        )
                    )
                self.adapters.append(nn.Sequential(*model_list))
        else:
            raise ValueError(
                f"model_type must be one of ['mlp', 'vitBlock'], but got {model_type}"
            )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:  # bias can be None, e.g. when qkv_bias=False
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        outs = []
        for index, x in enumerate(inputs):
            B, C, H, W = x.shape
            # (B, C, H, W) -> (B, H*W, C): treat spatial positions as tokens.
            x = x.permute(0, 2, 3, 1)
            x = x.reshape(B, -1, C)
            x = self.adapters[index](x)
            # (B, H*W, C) -> (B, C, H, W): restore the feature-map layout.
            x = x.reshape(B, H, W, C)
            x = x.permute(0, 3, 1, 2)
            outs.append(x)
        return tuple(outs)
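

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the batch size,
    # channel counts, and spatial sizes below are illustrative, not values taken
    # from the KTDA configs. It only checks that FMM preserves feature shapes.
    import torch

    fmm = FMM(in_channels=[96, 192], rank_dim=4, mlp_nums=1, model_type="mlp")
    feats = (torch.randn(2, 96, 32, 32), torch.randn(2, 192, 16, 16))
    outs = fmm(feats)
    for feat, out in zip(feats, outs):
        assert out.shape == feat.shape  # adapters keep the (B, C, H, W) layout

    # Because of @MODELS.register_module(), the same module can also be built
    # from an mmseg/mmengine config, e.g.
    # MODELS.build(dict(type="FMM", in_channels=[96, 192])).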