Spaces:
Build error
Build error
from torch import nn | |
from .attention import PositionAttention, Attention | |
from .backbone import ResTranformer | |
from .model import Model | |
from .resnet import resnet45 | |
class BaseVision(Model):
    """Vision recognition model: backbone -> attention decoder -> linear classifier.

    The backbone is either ``ResTranformer`` (when ``backbone == 'transformer'``)
    or ``resnet45`` (any other value). The attention decoder is chosen by
    ``attention``: ``'position'`` builds ``PositionAttention`` and ``'attention'``
    builds ``Attention``. Any other value raises ``ValueError``. A final
    ``nn.Linear(d_model, num_classes)`` maps each attended vector to class logits.

    Args:
        dataset_max_length: Forwarded to ``Model.__init__``. ``self.max_length``
            (read below) is presumably derived from it by the base class — confirm.
        null_label: Forwarded to ``Model.__init__``.
        num_classes: Output size of the classifier layer.
        attention: Decoder type, ``'position'`` or ``'attention'``.
        attention_mode: Passed as ``mode`` to ``PositionAttention``. Unused for
            ``'attention'``.
        loss_weight: Stored on the instance and echoed in ``forward``'s output dict.
        d_model: Classifier input width (``self.out_channels``). It is also the
            first ``ResTranformer`` argument.
        nhead, d_inner, dropout, activation, backbone_ln: Passed positionally to
            ``ResTranformer``. Unused when the resnet45 backbone is selected.
        backbone: ``'transformer'`` selects ``ResTranformer``. Anything else
            selects ``resnet45()``.

    Raises:
        ValueError: If ``attention`` is not ``'position'`` or ``'attention'``.
            This is raised after the backbone has already been constructed.
    """

    def __init__(self, dataset_max_length, null_label, num_classes,
                 attention='position', attention_mode='nearest', loss_weight=1.0,
                 d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
                 backbone='transformer', backbone_ln=2):
        super().__init__(dataset_max_length, null_label)
        self.loss_weight = loss_weight
        self.out_channels = d_model

        # Feature extractor. Built before the decoder so submodule registration
        # order (and any seeded random init) stays the same.
        wants_transformer = backbone == 'transformer'
        self.backbone = (
            ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
            if wants_transformer
            else resnet45()
        )

        # Attention decoder producing one vector per decoding step.
        if attention == 'position':
            decoder = PositionAttention(
                max_length=self.max_length,
                mode=attention_mode,
            )
        elif attention == 'attention':
            # NOTE(review): 8 * 32 looks like the backbone's flattened H*W for a
            # fixed input size — confirm against the Attention implementation.
            decoder = Attention(
                max_length=self.max_length,
                n_feature=8 * 32,
            )
        else:
            raise ValueError(f'invalid attention: {attention}')
        self.attention = decoder

        # Per-step classifier over the attended feature vectors.
        self.cls = nn.Linear(self.out_channels, num_classes)

    def forward(self, images):
        """Run backbone, attention and classifier on a batch of images.

        Shape notes follow the original author's annotations:
        features (N, E, H, W); attended vectors (N, T, E); attention scores
        (N, T, H, W); logits (N, T, C).

        Returns:
            dict with keys ``'feature'`` (attended vectors), ``'logits'``,
            ``'pt_lengths'`` (from the inherited ``_get_length`` — semantics
            defined in ``Model``), ``'attn_scores'``, ``'loss_weight'`` and
            ``'name'`` (always ``'vision'``).
        """
        feature_map = self.backbone(images)  # (N, E, H, W)
        step_vectors, step_scores = self.attention(feature_map)  # (N, T, E), (N, T, H, W)
        class_logits = self.cls(step_vectors)  # (N, T, C)
        predicted_lengths = self._get_length(class_logits)

        output = {
            'feature': step_vectors,
            'logits': class_logits,
            'pt_lengths': predicted_lengths,
            'attn_scores': step_scores,
            'loss_weight': self.loss_weight,
            'name': 'vision',
        }
        return output