File size: 3,030 Bytes
c509e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from model.deep_lab_model.aspp import build_aspp
from model.deep_lab_model.decoder import build_decoder
from model.deep_lab_model.backbone import build_backbone
class DeepLab(nn.Module):
    """DeepLab segmentation network: backbone -> ASPP -> decoder -> upsample.

    Args:
        backbone: backbone name forwarded to ``build_backbone``, ``build_aspp``
            and ``build_decoder`` (e.g. ``'resnet'``, ``'mobilenet'``, ``'drn'``).
        output_stride: output stride requested from the backbone/ASPP. It is
            forced to 8 when ``backbone == 'drn'``.
        num_classes: forwarded to ``build_decoder``.
        sync_bn: if true, layers are built with ``SynchronizedBatchNorm2d``;
            otherwise with ``nn.BatchNorm2d``.
        freeze_bn: if true, batch-norm parameters are left out of
            ``get_1x_lr_params`` / ``get_10x_lr_params``, so an optimizer built
            from them will not update BN weights. This flag does NOT put the
            BN layers into eval mode by itself. Call ``freeze_bn()`` for that,
            and call it again after any ``model.train()``, because ``train()``
            resets every sub-module's mode.
    """

    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super().__init__()
        if backbone == 'drn':
            # DRN is always run at output stride 8, whatever the caller passed.
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        # Kept under a private name on purpose: assigning ``self.freeze_bn``
        # would create an instance attribute that shadows the ``freeze_bn()``
        # method, making ``model.freeze_bn()`` fail with
        # "TypeError: 'bool' object is not callable".
        self._freeze_bn_enabled = bool(freeze_bn)

    def forward(self, input):
        """Run backbone -> ASPP -> decoder and resize to the input's spatial size.

        ``input`` is presumably an (N, C, H, W) image batch (TODO confirm). The
        decoder output is bilinearly resized to ``input.size()[2:]``.
        """
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        return F.interpolate(x, size=input.size()[2:], mode='bilinear',
                             align_corners=True)

    def freeze_bn(self):
        """Switch every batch-norm layer (synchronized or plain) to eval mode."""
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def _iter_lr_params(self, modules):
        """Yield trainable parameters of Conv2d/BN layers inside ``modules``.

        Only Conv2d layers are considered when BN is frozen. Otherwise
        batch-norm layers are included too. Parameters of any other layer type
        are not yielded. Parameters with ``requires_grad == False`` are skipped.
        """
        if self._freeze_bn_enabled:
            layer_types = (nn.Conv2d,)
        else:
            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in modules:
            for m in module.modules():
                if isinstance(m, layer_types):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_1x_lr_params(self):
        """Yield the backbone's trainable parameters (the 1x learning-rate group)."""
        yield from self._iter_lr_params([self.backbone])

    def get_10x_lr_params(self):
        """Yield ASPP + decoder trainable parameters (the 10x learning-rate group)."""
        yield from self._iter_lr_params([self.aspp, self.decoder])
if __name__ == "__main__":
    # Smoke test: push one random 3-channel 513x513 batch through a MobileNet
    # DeepLab in eval mode and print the output shape.
    model = DeepLab(backbone='mobilenet', output_stride=16)
    model.eval()
    dummy_input = torch.rand(1, 3, 513, 513)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    print(output.size())
|