Fix `pos_embed` device
Browse files- modeling_aimv2.py +1 -1
modeling_aimv2.py
CHANGED
@@ -102,7 +102,7 @@ class AIMv2ViTPreprocessor(nn.Module):
|
|
102 |
pos_embed = get_sincos_pos_embed(
|
103 |
H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
|
104 |
)
|
105 |
-
tokens = tokens + pos_embed
|
106 |
return tokens
|
107 |
|
108 |
|
|
|
102 |
pos_embed = get_sincos_pos_embed(
|
103 |
H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
|
104 |
)
|
105 |
+
tokens = tokens + pos_embed.to(tokens.device)
|
106 |
return tokens
|
107 |
|
108 |
|