Spaces:
Running
Running
File size: 794 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from torch import nn
class NRTREncoder(nn.Module):
    """Strided two-stage CNN stem that turns an image into a column sequence.

    Each stage is ``Conv2d(3x3, stride=2, padding=1) -> ReLU -> BatchNorm2d``.
    With these settings every stage maps a spatial size ``S`` to ``ceil(S / 2)``.
    The final feature map has 64 channels and is unrolled into one feature
    vector per (downsampled) image column.

    Args:
        in_channels: number of channels of the input images.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Width advertised to downstream modules: 64 channels times the
        # remaining height. The value 512 = 64 * 8 corresponds to an input
        # height of 32 (32 -> 16 -> 8). NOTE(review): other input heights yield
        # a different actual feature width than this attribute reports.
        self.out_channels = 512  # 64*H

        stages = []
        prev_width = in_channels
        for width in (32, 64):
            stages.extend((
                nn.Conv2d(
                    in_channels=prev_width,
                    out_channels=width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.ReLU(),
                nn.BatchNorm2d(width),
            ))
            prev_width = width
        # A single flat Sequential keeps submodule indices block.0 ... block.5,
        # so state_dict keys are identical to the hand-written layer list.
        self.block = nn.Sequential(*stages)

    def forward(self, images):
        """Encode a batch of images into a per-column feature sequence.

        Args:
            images: 4-D tensor laid out as (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, W', H' * 64), where H' and W' are the
            downsampled height and width. Along the last dimension the
            feature for row ``h`` and channel ``c`` sits at ``h * 64 + c``.
        """
        feats = self.block(images)  # (B, 64, H', W')
        # Swapping axes 1 and 3 gives (B, W', H', C): one slice per column.
        columns = feats.transpose(1, 3)
        return columns.flatten(start_dim=2)  # B, W, H*C
|