from torch import nn

from .attention import PositionAttention, Attention
from .backbone import ResTranformer
from .model import Model
from .resnet import resnet45


class BaseVision(Model):
    def __init__(self, dataset_max_length, null_label, num_classes,
                 attention='position', attention_mode='nearest', loss_weight=1.0,
                 d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
                 backbone='transformer', backbone_ln=2):
        super().__init__(dataset_max_length, null_label)
        self.loss_weight = loss_weight
        self.out_channels = d_model

        # Feature extractor: ResNet-45 followed by a Transformer encoder,
        # or a plain ResNet-45 backbone.
        if backbone == 'transformer':
            self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln)
        else:
            self.backbone = resnet45()

        # Attention module that aggregates 2-D feature maps into per-character vectors.
        if attention == 'position':
            self.attention = PositionAttention(
                max_length=self.max_length,
                mode=attention_mode,
            )
        elif attention == 'attention':
            self.attention = Attention(
                max_length=self.max_length,
                n_feature=8 * 32,
            )
        else:
            raise ValueError(f'invalid attention: {attention}')

        self.cls = nn.Linear(self.out_channels, num_classes)

    def forward(self, images):
        features = self.backbone(images)  # (N, E, H, W)
        attn_vecs, attn_scores = self.attention(features)  # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'}
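

# --- Hypothetical usage sketch (not part of the original module) ---
# Builds a BaseVision model with illustrative hyper-parameters and runs a
# dummy batch through it. The values dataset_max_length=25, null_label=0,
# num_classes=37 and the 32x128 input size are assumptions for illustration,
# not values taken from this repository's configs.
def _smoke_test():
    import torch

    model = BaseVision(dataset_max_length=25, null_label=0, num_classes=37)
    images = torch.randn(2, 3, 32, 128)  # dummy batch: (N, C, H, W)
    out = model(images)
    # logits: (N, T, C), pt_lengths: (N,)
    return out['logits'].shape, out['pt_lengths'].shape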