Spaces:
Build error
Build error
File size: 1,032 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from .resnet import resnet45
from .transformer import PositionalEncoding
class ResTranformer(nn.Module):
    """ResNet-45 backbone followed by a Transformer encoder over its feature map.

    The backbone's 4-D feature map ``(n, c, h, w)`` is flattened into a
    sequence of ``h * w`` positions. Positional encodings are added, the
    sequence is refined by a stack of ``TransformerEncoder`` layers, and the
    result is reshaped back to ``(n, c, h, w)``.

    Args:
        d_model: Embedding size of each sequence position. NOTE(review): the
            channel count produced by ``resnet45()`` (defined elsewhere) must
            equal this value for the encoder to accept the sequence — TODO
            confirm.
        nhead: Number of attention heads per encoder layer.
        d_inner: Hidden size of each encoder layer's feed-forward network
            (passed as ``dim_feedforward``).
        dropout: Dropout probability used inside each encoder layer.
        activation: Feed-forward activation forwarded to
            ``TransformerEncoderLayer`` (e.g. ``'relu'`` or ``'gelu'``).
        backbone_ln: Number of stacked encoder layers.
        max_len: Maximum sequence length handed to ``PositionalEncoding``.
            Defaults to ``8 * 32``, the value that used to be hard-coded
            (presumably an 8x32 backbone feature map — confirm). Raise it if
            the backbone's ``h * w`` can exceed this.
    """

    def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1,
                 activation='relu', backbone_ln=2, max_len=8 * 32):
        super().__init__()
        self.resnet = resnet45()
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)
        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_inner,
            dropout=dropout,
            activation=activation,
        )
        self.transformer = TransformerEncoder(encoder_layer, backbone_ln)

    def forward(self, images):
        """Encode ``images`` and return a feature map shaped like the backbone output.

        Args:
            images: Input batch, passed unchanged to the ResNet backbone.

        Returns:
            Tensor of shape ``(n, c, h, w)``, the same shape as the backbone
            output, after positional encoding and the Transformer encoder.
        """
        feature = self.resnet(images)
        n, c, h, w = feature.shape
        # (n, c, h, w) -> (n, c, h*w) -> (h*w, n, c): TransformerEncoder
        # defaults to batch_first=False, i.e. sequence-first input.
        feature = feature.flatten(2).permute(2, 0, 1)
        feature = self.pos_encoder(feature)
        feature = self.transformer(feature)
        # (h*w, n, c) -> (n, c, h*w) -> (n, c, h, w). reshape (unlike view)
        # also tolerates a non-contiguous tensor, copying only when required.
        feature = feature.permute(1, 2, 0).reshape(n, c, h, w)
        return feature
|