import torch
import torch.nn as nn
import torch.nn.functional as F

from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from model.deep_lab_model.aspp import build_aspp
from model.deep_lab_model.decoder import build_decoder
from model.deep_lab_model.backbone import build_backbone


class DeepLab(nn.Module):
    """Segmentation network assembled from a backbone, an ASPP head and a decoder.

    The three sub-networks are created by ``build_backbone``, ``build_aspp``
    and ``build_decoder``; ``forward`` chains them and resizes the decoder
    output to the spatial size of the input.

    Args:
        backbone: Name forwarded to the three ``build_*`` factories. The value
            ``'drn'`` forces ``output_stride`` to 8.
        output_stride: Forwarded to ``build_backbone`` and ``build_aspp``.
        num_classes: Forwarded to ``build_decoder``.
        sync_bn: When truthy, ``SynchronizedBatchNorm2d`` is handed to the
            factories as the normalisation layer; otherwise ``nn.BatchNorm2d``.
        freeze_bn: Stored as the boolean attribute ``self.freeze_bn``. When
            truthy, ``get_1x_lr_params`` / ``get_10x_lr_params`` yield
            convolution parameters only. Setting the flag does not by itself
            switch any layer to eval mode; call ``freeze_batchnorm()`` for that.
    """

    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        # NOTE: this instance attribute shadows the class-level ``freeze_bn``
        # function below, so ``model.freeze_bn`` is this bool, not a method.
        # It is kept under this name because the parameter-group helpers (and
        # possibly external code) read it as a flag.
        self.freeze_bn = freeze_bn

    def forward(self, input):
        """Run backbone -> ASPP -> decoder and upsample to the input size.

        Args:
            input: Tensor whose dimensions from index 2 onward are used as the
                target size (assumes an ``(N, C, H, W)`` image batch --
                TODO confirm against callers).

        Returns:
            The decoder output, bilinearly interpolated (``align_corners=True``)
            to ``input.size()[2:]``.
        """
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear',
                          align_corners=True)
        return x

    def freeze_batchnorm(self):
        """Switch every batch-norm layer of the model to eval mode.

        Affects all ``SynchronizedBatchNorm2d`` and ``nn.BatchNorm2d``
        sub-modules. This is the instance-callable entry point; the name
        ``freeze_bn`` cannot be used on an instance because ``__init__``
        stores a bool under that name.
        """
        for module in self.modules():
            if isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.eval()

    def freeze_bn(self):
        """Class-level alias of ``freeze_batchnorm``.

        Reachable only as ``DeepLab.freeze_bn(model)``: on an instance the
        attribute ``freeze_bn`` resolves to the bool flag set in ``__init__``,
        so ``model.freeze_bn()`` would raise ``TypeError``. Prefer
        ``model.freeze_batchnorm()``.
        """
        self.freeze_batchnorm()

    def _trainable_params(self, modules):
        """Yield trainable parameters of selected layers inside ``modules``.

        Walks each module's sub-modules in ``modules()`` order and yields the
        parameters with ``requires_grad`` set that belong to ``nn.Conv2d``
        layers, plus batch-norm layers unless ``self.freeze_bn`` is truthy.
        """
        if self.freeze_bn:
            wanted = (nn.Conv2d,)
        else:
            wanted = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)

        for module in modules:
            for submodule in module.modules():
                if not isinstance(submodule, wanted):
                    continue
                for param in submodule.parameters():
                    if param.requires_grad:
                        yield param

    def get_1x_lr_params(self):
        """Return an iterator over the backbone's trainable parameters.

        Only conv (and, unless ``self.freeze_bn`` is set, batch-norm)
        parameters with ``requires_grad`` are produced.
        """
        return self._trainable_params([self.backbone])

    def get_10x_lr_params(self):
        """Return an iterator over the ASPP and decoder trainable parameters.

        Only conv (and, unless ``self.freeze_bn`` is set, batch-norm)
        parameters with ``requires_grad`` are produced.
        """
        return self._trainable_params([self.aspp, self.decoder])


if __name__ == "__main__":
    # Smoke test: push one random 513x513 RGB-sized tensor through the model
    # and print the shape of the result.
    model = DeepLab(backbone='mobilenet', output_stride=16)
    model.eval()
    dummy_input = torch.rand(1, 3, 513, 513)
    output = model(dummy_input)
    print(output.size())