Spaces:
Running
Running
from torch import nn | |
class NRTREncoder(nn.Module): | |
def __init__(self, in_channels): | |
super(NRTREncoder, self).__init__() | |
self.out_channels = 512 # 64*H | |
self.block = nn.Sequential( | |
nn.Conv2d( | |
in_channels=in_channels, | |
out_channels=32, | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
), nn.ReLU(), nn.BatchNorm2d(32), | |
nn.Conv2d( | |
in_channels=32, | |
out_channels=64, | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
), nn.ReLU(), nn.BatchNorm2d(64)) | |
def forward(self, images): | |
x = self.block(images) | |
x = x.permute(0, 3, 2, 1).flatten(2) # B, W, H*C | |
return x | |