fix: crash because pos enc is on CPU device

#1
Files changed (1) hide show
  1. modeling_aimv2.py +1 -1
modeling_aimv2.py CHANGED
@@ -101,7 +101,7 @@ class AIMv2ViTPreprocessor(nn.Module):
101
  tokens = self.patchifier(x)
102
  pos_embed = get_sincos_pos_embed(
103
  H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
104
- )
105
  tokens = tokens + pos_embed
106
  return tokens
107
 
 
101
  tokens = self.patchifier(x)
102
  pos_embed = get_sincos_pos_embed(
103
  H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
104
+ ).to(tokens.device)
105
  tokens = tokens + pos_embed
106
  return tokens
107