# strexp/strhub/models/abinet/model_vision.py
from torch import nn
from .attention import PositionAttention, Attention
from .backbone import ResTranformer
from .model import Model
from .resnet import resnet45


class BaseVision(Model):
    """Vision branch of ABINet: backbone, attentional decoder, and character classifier."""

    def __init__(self, dataset_max_length, null_label, num_classes,
                 attention='position', attention_mode='nearest', loss_weight=1.0,
                 d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
                 backbone='transformer', backbone_ln=2):
        super().__init__(dataset_max_length, null_label)
        self.loss_weight = loss_weight
        self.out_channels = d_model

        # Feature extractor: hybrid ResNet + Transformer encoder, or plain ResNet-45.
        if backbone == 'transformer':
            self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
        else:
            self.backbone = resnet45()

        # Attention module that pools the 2D feature map into per-character vectors.
        if attention == 'position':
            self.attention = PositionAttention(
                max_length=self.max_length,
                mode=attention_mode
            )
        elif attention == 'attention':
            self.attention = Attention(
                max_length=self.max_length,
                n_feature=8 * 32,
            )
        else:
            raise ValueError(f'invalid attention: {attention}')

        self.cls = nn.Linear(self.out_channels, num_classes)

    def forward(self, images):
        features = self.backbone(images)                    # (N, E, H, W)
        attn_vecs, attn_scores = self.attention(features)   # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs)                        # (N, T, C)
        pt_lengths = self._get_length(logits)
        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
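

# Minimal usage sketch (not part of the original file): illustrates the expected
# input/output shapes. The 32x128 RGB crop size, the 37-class charset, and the
# null_label value are assumptions chosen for illustration, not values fixed by
# this module; they depend on the surrounding strexp/ABINet configuration.
if __name__ == '__main__':
    import torch

    model = BaseVision(dataset_max_length=25, null_label=0, num_classes=37)
    images = torch.randn(2, 3, 32, 128)     # (N, C, H, W) dummy batch
    out = model(images)
    print(out['logits'].shape)              # (2, T, 37), T = self.max_length
    print(out['pt_lengths'])                # predicted sequence lengths per sample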