Spaces:
Build error
Build error
import torch.nn as nn | |
from strhub.models.modules import BidirectionalLSTM | |
class CRNN(nn.Module):
    """CNN feature extractor followed by a two-layer bidirectional LSTM head.

    Seven conv layers, with max-pooling between stages, reduce the input
    image to a feature map of height 1. The width axis of that map is then
    treated as a sequence and passed through two stacked
    ``BidirectionalLSTM`` modules.

    Args:
        img_h: Expected input image height. It is only validated here and is
            not otherwise used by the model; it must be a multiple of 16.
            With this exact layer stack, a height of 32 is what reduces to a
            final feature height of 1: 16 is too small for the last 2x2
            conv, and 64 yields height 3. ``forward`` rejects any final
            height other than 1.
        nc: Number of input channels of the first convolution.
        nclass: Last size argument of the final ``BidirectionalLSTM``
            (presumably its output size, i.e. the number of classes --
            confirm against ``BidirectionalLSTM``).
        nh: Size passed to both ``BidirectionalLSTM`` layers (presumably the
            hidden size, and also the first layer's output size -- confirm).
        leaky_relu: If True, use ``LeakyReLU(0.2)`` after every conv instead
            of ``ReLU``.

    Raises:
        ValueError: If ``img_h`` is not a multiple of 16.
    """

    def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if img_h % 16 != 0:
            raise ValueError('img_h has to be a multiple of 16')

        # Per-conv-layer hyper-parameters, indexed by conv layer number.
        ks = [3, 3, 3, 3, 3, 3, 2]  # kernel sizes
        ps = [1, 1, 1, 1, 1, 1, 0]  # paddings
        ss = [1, 1, 1, 1, 1, 1, 1]  # strides
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # Append conv{i} [+ batchnorm{i}] + relu{i} to ``cnn``. These
            # submodule names end up in the state_dict keys, so keep them stable.
            n_in = nc if i == 0 else nm[i - 1]
            n_out = nm[i]
            # A conv bias is redundant when BatchNorm follows (BN re-centres).
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i], bias=not batch_norm),
            )
            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
            if leaky_relu:
                cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{i}', nn.ReLU(True))

        # Shape comments are C x H x W for a 32 x 128 input.
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x16x64
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x8x32
        conv_relu(2, True)
        conv_relu(3)
        # Halve the height only; the width padding of 1 makes the width grow by 1.
        cnn.add_module('pooling2', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x33
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x34
        # 2x2 kernel with no padding: height 2 -> 1, width shrinks by 1.
        conv_relu(6, True)  # 512x1x33

        self.cnn = cnn
        # The first LSTM consumes the CNN's final channel count (nm[-1] == 512).
        self.rnn = nn.Sequential(
            BidirectionalLSTM(nm[-1], nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, input):
        """Run the CNN, then the recurrent head, over the width axis.

        Args:
            input: Image batch for ``self.cnn``. The CNN output is unpacked
                as a 4-D ``(b, c, h, w)`` tensor, so this is presumably
                ``(b, nc, H, W)``.

        Returns:
            The output of ``self.rnn`` applied to the ``(b, w, c)`` feature
            sequence.

        Raises:
            ValueError: If the CNN output is not 4-D, or its height is not 1.
        """
        features = self.cnn(input)
        # The 4-way unpack doubles as a check that the feature map is 4-D.
        _, _, height, _ = features.size()
        if height != 1:
            raise ValueError(f'the height of conv must be 1, got {height}')
        # Drop the singleton height axis and make width the sequence axis.
        seq = features.squeeze(2).transpose(1, 2)  # [b, w, c]
        return self.rnn(seq)