import math
import warnings

import torch
import torch.nn as nn

from IndicPhotoOCR.detection import east_config as cfg
from IndicPhotoOCR.detection import east_utils as utils


def conv_bn(inp, oup, stride):
    """Return a 3x3 Conv2d -> BatchNorm2d -> ReLU6 stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Convolution stride (padding is fixed at 1).
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    )


class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    An optional 1x1 expansion (only when ``expand_ratio != 1``), a 3x3
    depthwise convolution, then a linear 1x1 projection. An identity shortcut
    is added when the block keeps both resolution (stride 1) and channel count.

    Args:
        inp: Input channels.
        oup: Output channels.
        stride: Depthwise-conv stride; must be 1 or 2.
        expand_ratio: Hidden width multiplier (hidden = round(inp * ratio)).
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: expand channels (skipped when ratio is 1, where hidden_dim == inp)
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: no activation after the projection
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        # Layer order (and therefore state-dict indices conv.0, conv.1, ...)
        # is identical to building the two branches separately.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileNetV2(nn.Module):
    """Truncated MobileNetV2 feature extractor (no classifier head).

    ``self.features`` holds 17 layers: the stem conv plus 16 inverted-residual
    blocks, ending at 160 output channels (scaled by ``width_mult``).

    Args:
        width_mult: Channel-width multiplier applied to every stage.
    """

    def __init__(self, width_mult=1.):
        super().__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expansion), c (out channels), n (repeats), s (first stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            # The usual final [6, 320, 1, 1] stage is intentionally omitted.
        ]

        input_channel = int(input_channel * width_mult)
        # Kept for attribute compatibility; no classifier head is built here.
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel

        features = [conv_bn(3, input_channel, 2)]
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of each stage may downsample.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        self.features = nn.Sequential(*features)

        self._initialize_weights()

    def forward(self, x):
        """Return the final 160-channel (x width_mult) feature map."""
        return self.features(x)

    def _initialize_weights(self):
        """Re-initialize conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()


def mobilenet(pretrained=True, **kwargs):
    """Construct the truncated MobileNetV2 backbone used by ``East``.

    Args:
        pretrained: If True, load weights from ``cfg.pretrained_basemodel_path``
            (on CPU, ``weights_only=True``), copying only tensors whose keys
            also exist in this model's state dict.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        A ``MobileNetV2`` instance with default width.
    """
    model = MobileNetV2()
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(
            cfg.pretrained_basemodel_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        # Drop checkpoint entries this truncated network does not have
        # (e.g. the omitted final stage / classifier).
        matched = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        if not matched:
            # Previously this case silently left the backbone randomly initialized.
            warnings.warn(
                f"No keys from {cfg.pretrained_basemodel_path!r} matched the "
                "MobileNetV2 state dict; backbone weights were not loaded.",
                stacklevel=2,
            )
        model_dict.update(matched)
        model.load_state_dict(model_dict)
    return model


class East(nn.Module):
    """EAST text detector with a truncated MobileNetV2 encoder.

    The encoder is split into four stages; a top-down branch upsamples the
    deepest features by 2x three times, concatenating the matching encoder
    stage each time, and predicts maps at 1/4 of the input resolution.

    Args:
        pretrained: Forwarded to ``mobilenet`` to load backbone weights.
        geo_scale: Multiplier applied to the 4 sigmoid geometry channels
            (default 512, the previously hard-coded value).
    """

    def __init__(self, *, pretrained=True, geo_scale=512):
        super().__init__()
        self.mobilenet = mobilenet(pretrained)
        self.geo_scale = geo_scale

        # Re-wrap each slice in a fresh Sequential so every stage's state-dict
        # keys start at 0 (s2.0, not s2.4); changing this renames checkpoint keys.
        features = self.mobilenet.features
        self.s1 = nn.Sequential(*features[0:4])    # 24 ch,  stride 4
        self.s2 = nn.Sequential(*features[4:7])    # 32 ch,  stride 8
        self.s3 = nn.Sequential(*features[7:14])   # 96 ch,  stride 16
        self.s4 = nn.Sequential(*features[14:17])  # 160 ch, stride 32

        self.conv1 = nn.Conv2d(160 + 96, 128, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(128 + 32, 64, 1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.Conv2d(64 + 24, 64, 1)
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(64, 32, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(32)
        self.relu7 = nn.ReLU()

        # Output heads: score (1 ch), geometry (4 ch), angle (1 ch).
        self.conv8 = nn.Conv2d(32, 1, 1)
        self.sigmoid1 = nn.Sigmoid()
        self.conv9 = nn.Conv2d(32, 4, 1)
        self.sigmoid2 = nn.Sigmoid()
        self.conv10 = nn.Conv2d(32, 1, 1)
        self.sigmoid3 = nn.Sigmoid()

        # align_corners=False is the implicit default; stating it avoids the
        # PyTorch default-change warning without altering results.
        self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, images):
        """Predict score and geometry maps.

        Args:
            images: Input batch, first passed through
                ``utils.mean_image_subtraction``. NOTE(review): assumed to be
                (N, 3, H, W) with H and W multiples of 32; other sizes can make
                the skip-connection ``torch.cat`` calls fail on mismatched
                spatial sizes.

        Returns:
            ``(F_score, F_geometry)``: ``F_score`` is (N, 1, H/4, W/4) in
            (0, 1); ``F_geometry`` is (N, 5, H/4, W/4): 4 channels in
            (0, geo_scale) followed by an angle channel in (-pi/4, pi/4).
        """
        images = utils.mean_image_subtraction(images)
        f0 = self.s1(images)  # 24 ch,  H/4
        f1 = self.s2(f0)      # 32 ch,  H/8
        f2 = self.s3(f1)      # 96 ch,  H/16
        f3 = self.s4(f2)      # 160 ch, H/32

        # Merge 1: H/32 -> H/16, fuse with f2.
        g = self.unpool1(f3)
        c = self.relu1(self.bn1(self.conv1(torch.cat((g, f2), 1))))
        h = self.relu2(self.bn2(self.conv2(c)))  # 128 ch, H/16

        # Merge 2: H/16 -> H/8, fuse with f1.
        g = self.unpool2(h)
        c = self.relu3(self.bn3(self.conv3(torch.cat((g, f1), 1))))
        h = self.relu4(self.bn4(self.conv4(c)))  # 64 ch, H/8

        # Merge 3: H/8 -> H/4, fuse with f0.
        g = self.unpool3(h)
        c = self.relu5(self.bn5(self.conv5(torch.cat((g, f0), 1))))
        h = self.relu6(self.bn6(self.conv6(c)))  # 32 ch, H/4
        g = self.relu7(self.bn7(self.conv7(h)))  # 32 ch, H/4

        F_score = self.sigmoid1(self.conv8(g))
        geo_map = self.sigmoid2(self.conv9(g)) * self.geo_scale
        # Map sigmoid output (0, 1) to an angle in (-pi/4, pi/4).
        angle_map = (self.sigmoid3(self.conv10(g)) - 0.5) * math.pi / 2
        F_geometry = torch.cat((geo_map, angle_map), 1)
        return F_score, F_geometry


# NOTE(review): built at import time, so importing this module loads the
# backbone weights from cfg.pretrained_basemodel_path. Kept because callers
# may import `model` directly.
model = East()