michalk8 commited on
Commit
c68d048
1 Parent(s): 0bfab3c

Fix `pos_embed` device

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +1 -1
modeling_aimv2.py CHANGED
@@ -102,7 +102,7 @@ class AIMv2ViTPreprocessor(nn.Module):
102
  pos_embed = get_sincos_pos_embed(
103
  H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
104
  )
105
- tokens = tokens + pos_embed
106
  return tokens
107
 
108
 
 
102
  pos_embed = get_sincos_pos_embed(
103
  H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
104
  )
105
+ tokens = tokens + pos_embed.to(tokens.device)
106
  return tokens
107
 
108