Spaces: Running on Zero
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
from model.deep_lab_model.aspp import build_aspp | |
from model.deep_lab_model.decoder import build_decoder | |
from model.deep_lab_model.backbone import build_backbone | |
class DeepLab(nn.Module):
    """DeepLab segmentation network: backbone -> ASPP -> decoder -> upsample.

    Args:
        backbone: Backbone name passed to ``build_backbone``, ``build_aspp``
            and ``build_decoder`` (e.g. ``'resnet'``, ``'mobilenet'``,
            ``'drn'``).
        output_stride: Passed to the backbone and ASPP builders. Forced to 8
            when ``backbone == 'drn'``, regardless of the value given.
        num_classes: Passed to ``build_decoder``.
        sync_bn: If truthy, normalization layers are built with
            ``SynchronizedBatchNorm2d``; otherwise with ``nn.BatchNorm2d``.
        freeze_bn: If truthy, ``get_1x_lr_params`` / ``get_10x_lr_params``
            yield only ``nn.Conv2d`` parameters, leaving batch-norm
            parameters out of whatever optimizer consumes them. This flag
            does not itself switch BN layers to eval mode; call the
            ``freeze_bn()`` method for that.
    """

    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super().__init__()
        if backbone == 'drn':
            # DRN always runs at output stride 8 here, overriding the argument.
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        # Stored under a private name on purpose: an instance attribute called
        # ``freeze_bn`` would shadow the ``freeze_bn()`` method below and make
        # ``model.freeze_bn()`` fail with "'bool' object is not callable".
        self._freeze_bn = bool(freeze_bn)

    def forward(self, input):
        """Run backbone, ASPP and decoder, then resize to the input's size.

        Args:
            input: Image batch. ``input.size()[2:]`` is used as the target
                spatial size and bilinear interpolation needs 4-D tensors, so
                this presumably expects NCHW — TODO confirm channel count
                against the chosen backbone.

        Returns:
            The decoder output, bilinearly resized (``align_corners=True``)
            to the spatial size of ``input``.
        """
        features, low_level_feat = self.backbone(input)
        features = self.aspp(features)
        decoded = self.decoder(features, low_level_feat)
        return F.interpolate(decoded, size=input.size()[2:], mode='bilinear',
                             align_corners=True)

    def freeze_bn(self):
        """Put every batch-norm layer of the model into eval mode.

        For ``nn.BatchNorm2d`` this means normalizing with the running
        statistics instead of batch statistics and no longer updating them
        (``SynchronizedBatchNorm2d`` presumably behaves the same — verify).
        Affine parameters keep ``requires_grad``; use ``freeze_bn=True`` in
        the constructor to keep them out of the LR parameter groups.

        Note: ``nn.Module.train()`` recursively resets submodules to training
        mode, so call this *after* ``model.train()``.
        """
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def get_1x_lr_params(self):
        """Yield trainable backbone parameters for the "1x" parameter group.

        The name suggests these get the base learning rate — confirm with the
        optimizer setup in the caller.
        """
        yield from self._iter_lr_params([self.backbone])

    def get_10x_lr_params(self):
        """Yield trainable ASPP and decoder parameters for the "10x" group.

        The name suggests these get 10x the base learning rate — confirm with
        the optimizer setup in the caller.
        """
        yield from self._iter_lr_params([self.aspp, self.decoder])

    def _iter_lr_params(self, modules):
        """Yield ``requires_grad`` params of selected layer types in ``modules``.

        ``nn.Conv2d`` layers are always included. Batch-norm layers
        (``SynchronizedBatchNorm2d`` / ``nn.BatchNorm2d``) are included only
        when the model was built with ``freeze_bn=False``.
        """
        if self._freeze_bn:
            layer_types = (nn.Conv2d,)
        else:
            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in modules:
            # ``modules()`` walks the whole subtree, so nested layers are found.
            for layer in module.modules():
                if isinstance(layer, layer_types):
                    for param in layer.parameters():
                        if param.requires_grad:
                            yield param
if __name__ == "__main__":
    # Smoke test: build a MobileNet-backed model and print the output shape
    # for one random 3-channel 513x513 image.
    model = DeepLab(backbone='mobilenet', output_stride=16)
    model.eval()
    dummy = torch.rand(1, 3, 513, 513)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = model(dummy)
    print(output.size())