strexp / strhub /models /abinet /backbone.py
markytools's picture
added strexp
d61b9c7
raw
history blame
1.03 kB
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from .resnet import resnet45
from .transformer import PositionalEncoding
class ResTranformer(nn.Module):
    """ResNet feature extractor followed by a Transformer encoder.

    The image batch is passed through ``resnet45()``, the resulting 4-D
    feature map is flattened into a sequence over its spatial positions,
    positional encoding is added, the sequence is run through a
    ``TransformerEncoder``, and the result is folded back into the
    feature-map layout produced by the ResNet.

    Args:
        d_model: Embedding width given to the positional encoder and to each
            encoder layer.
        nhead: Number of attention heads per encoder layer.
        d_inner: ``dim_feedforward`` of each encoder layer.
        dropout: Dropout probability of each encoder layer.
        activation: Activation used in each encoder layer's feed-forward part.
        backbone_ln: Number of stacked encoder layers.
    """

    def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2):
        super().__init__()
        self.resnet = resnet45()
        # NOTE(review): max_len=8 * 32 presumably matches the h * w grid that
        # resnet45 emits for the expected input size -- confirm against resnet45.
        self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32)
        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_inner,
            dropout=dropout,
            activation=activation,
        )
        self.transformer = TransformerEncoder(encoder_layer, backbone_ln)

    def forward(self, images):
        """Encode ``images`` and return a feature map shaped like the ResNet output.

        Args:
            images: Batch of images accepted by ``self.resnet``.

        Returns:
            Tensor of shape ``(n, c, h, w)``, the same shape as the ResNet
            feature map.
        """
        feature = self.resnet(images)
        n, c, h, w = feature.shape
        # (n, c, h, w) -> (h*w, n, c): sequence-first layout for the encoder.
        # reshape (not view) so a backbone output whose spatial dims are not
        # stride-mergeable is copied instead of raising a RuntimeError.
        # Presumably c must equal d_model for the encoder -- TODO confirm.
        feature = feature.reshape(n, c, -1).permute(2, 0, 1)
        feature = self.pos_encoder(feature)
        feature = self.transformer(feature)
        # (h*w, n, c) -> (n, c, h, w): restore the spatial layout.
        feature = feature.permute(1, 2, 0).reshape(n, c, h, w)
        return feature