Update modularStarEncoder.py
Browse files- modularStarEncoder.py +5 -0
modularStarEncoder.py
CHANGED
@@ -204,7 +204,12 @@ def get_pooling_mask(
|
|
204 |
|
205 |
repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
|
206 |
|
|
|
|
|
|
|
|
|
207 |
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
|
|
208 |
|
209 |
pooling_mask = (repeated_idx <= ranges).long()
|
210 |
|
|
|
204 |
|
205 |
repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
|
206 |
|
207 |
+
DEVICE = input_ids.get_device()
|
208 |
+
|
209 |
+
if DEVICE<0:
|
210 |
+
DEVICE = "cpu"
|
211 |
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
212 |
+
ranges = ranges.to(DEVICE)
|
213 |
|
214 |
pooling_mask = (repeated_idx <= ranges).long()
|
215 |
|