Spaces:
Build error
Build error
File size: 665 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
"""
Implementation of ViTSTR based on timm VisionTransformer.
TODO:
1) distilled DeiT backbone
2) base DeiT backbone
Copyright 2021 Rowel Atienza
"""
from timm.models.vision_transformer import VisionTransformer
class ViTSTR(VisionTransformer):
    """
    Scene text recognition (STR) model built on timm's VisionTransformer.

    ViTSTR is essentially a ViT initialised from DeiT weights; the only
    change is that the classification head is applied to a sequence of
    tokens so that one character is predicted per position.
    """

    def forward(self, x, seqlen: int = 25):
        """
        Predict per-position character logits.

        Args:
            x: Input batch, handed unchanged to ``forward_features``.
            seqlen: Upper bound on the number of leading tokens that are
                kept and classified (default 25).

        Returns:
            Tensor viewed as (batch, kept tokens, ``self.num_classes``).
        """
        # NOTE(review): assumes the inherited forward_features returns the
        # whole token sequence shaped (batch, tokens, embsize) rather than a
        # pooled vector -- confirm against the installed timm version.
        tokens = self.forward_features(x)[:, :seqlen]

        n_batch, n_steps, emb_dim = tokens.shape

        # Collapse batch and sequence axes so the head sees one row per
        # token, then restore the (batch, sequence) layout on the logits.
        logits = self.head(tokens.reshape(n_batch * n_steps, emb_dim))
        return logits.view(n_batch, n_steps, self.num_classes)
|