XavierJiezou's picture
Update ktda/models/fam.py
7863f92 verified
from mmseg.registry import MODELS
from mmengine.model import BaseModule
from torch import nn as nn
from torch.nn import functional as F
from timm.models.layers import trunc_normal_
@MODELS.register_module()
class FAM(BaseModule):
    """Project each input feature map with a 1x1 conv and resize it to a common size.

    Branch ``i`` applies ``convert[i]`` (a 1x1 ``nn.Conv2d``) to ``inputs[i]`` and then
    bilinearly resizes the result to ``output_size``.

    Args:
        in_channels (Sequence[int]): Channel count of each input feature map; one
            1x1 conv is built per entry.
        out_channels (int | Sequence[int]): Output channels per branch. An int is
            broadcast to every branch; a sequence must match ``in_channels`` in length.
        output_size (int | tuple[int, int] | list[int]): Target spatial size. A scalar
            ``s`` means ``(s, s)``; a tuple/list is taken as ``(height, width)``.
        init_cfg (dict | list[dict], optional): Forwarded to ``BaseModule``.

    Raises:
        ValueError: If ``out_channels`` is a sequence whose length differs from
            ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, output_size, init_cfg=None):
        super().__init__(init_cfg)
        self.output_size = output_size
        if isinstance(out_channels, int):
            out_channels = [out_channels] * len(in_channels)
        elif len(out_channels) != len(in_channels):
            # zip() below would otherwise silently build fewer branches than configured.
            raise ValueError(
                f"out_channels has {len(out_channels)} entries but "
                f"in_channels has {len(in_channels)}"
            )
        self.convert = nn.ModuleList(
            [
                nn.Conv2d(in_channel, out_channel, kernel_size=1)
                for in_channel, out_channel in zip(in_channels, out_channels)
            ]
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std=0.02) init for Conv2d/Linear weights; zero the bias."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers constructed with bias=False expose m.bias as None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _target_size(self):
        """Return ``self.output_size`` as a ``(height, width)`` pair."""
        if isinstance(self.output_size, (tuple, list)):
            return tuple(self.output_size)
        # Scalar size: square output, exactly as the original implementation did.
        return (self.output_size, self.output_size)

    def forward(self, inputs):
        """Apply the per-branch 1x1 conv and bilinear resize to each input.

        Args:
            inputs (Sequence[Tensor]): One feature map per branch, in the same order
                as ``in_channels``. Bilinear interpolation requires 4D tensors,
                presumably ``(N, C, H, W)``.

        Returns:
            tuple[Tensor]: The resized feature maps, one per input.
        """
        size = self._target_size()
        outs = []
        for index, x in enumerate(inputs):
            x = self.convert[index](x)
            x = F.interpolate(x, size=size, mode="bilinear", align_corners=False)
            outs.append(x)
        return tuple(outs)