import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision.models as models
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg

model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

# Boundaries into ``base_net.features`` for the five backbone stages:
# stage k covers features[bounds[k - 1]:bounds[k]].
_STAGE_BOUNDARIES = {
    "vgg16": (0, 5, 10, 17, 24, 31),
    "vgg16_bn": (0, 7, 14, 24, 34, 44),
}


class VggNet(nn.Module):
    """VGG-16 feature extractor split into five sequential stages.

    Args:
        name: Backbone variant, either ``"vgg16"`` or ``"vgg16_bn"``.
        pretrain: If True, load the ImageNet weights listed in ``model_urls``
            (fetched via ``model_zoo.load_url`` into ``./cache``) onto
            ``cfg.device`` before splitting the network.

    Raises:
        ValueError: If ``name`` is not a supported variant.
    """

    def __init__(self, name="vgg16", pretrain=True):
        super().__init__()
        if name == "vgg16":
            base_net = models.vgg16(pretrained=False)
        elif name == "vgg16_bn":
            base_net = models.vgg16_bn(pretrained=False)
        else:
            # Previously this only printed a message and then crashed on an
            # unbound ``base_net``; fail early with a clear error instead.
            raise ValueError(
                "base model {!r} is not supported; expected one of {}".format(
                    name, sorted(_STAGE_BOUNDARIES)))

        if pretrain:
            print("load the {} weight from ./cache".format(name))
            base_net.load_state_dict(
                model_zoo.load_url(model_urls[name], model_dir="./cache",
                                   map_location=torch.device(cfg.device)))

        bounds = _STAGE_BOUNDARIES[name]
        layers = list(base_net.features)
        # Each stage is rebuilt as a fresh nn.Sequential so its children are
        # keyed "0", "1", ... (slicing ``base_net.features`` directly would keep
        # the original indices and change the state_dict keys that saved
        # checkpoints rely on).
        self.stage1 = nn.Sequential(*layers[bounds[0]:bounds[1]])
        self.stage2 = nn.Sequential(*layers[bounds[1]:bounds[2]])
        self.stage3 = nn.Sequential(*layers[bounds[2]:bounds[3]])
        self.stage4 = nn.Sequential(*layers[bounds[3]:bounds[4]])
        self.stage5 = nn.Sequential(*layers[bounds[4]:bounds[5]])

    def forward(self, x):
        """Run ``x`` through the stages in order.

        Returns:
            Tuple ``(C1, C2, C3, C4, C5)`` of the outputs of stages 1-5,
            shallowest first; each stage consumes the previous stage's output.
        """
        C1 = self.stage1(x)
        C2 = self.stage2(C1)
        C3 = self.stage3(C2)
        C4 = self.stage4(C3)
        C5 = self.stage5(C4)
        return C1, C2, C3, C4, C5


if __name__ == '__main__':
    images = torch.randn((4, 3, 512, 512))
    net = VggNet()
    for feature in net(images):
        print(feature.size())