Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 commited on
Commit
a9041e6
·
verified ·
1 Parent(s): f8be806

Update modularStarEncoder.py

Browse files
Files changed (1) hide show
  1. modularStarEncoder.py +5 -0
modularStarEncoder.py CHANGED
@@ -204,7 +204,12 @@ def get_pooling_mask(
204
 
205
  repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
206
 
 
 
 
 
207
  ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
 
208
 
209
  pooling_mask = (repeated_idx <= ranges).long()
210
 
 
204
 
205
  repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
206
 
207
+ DEVICE = input_ids.get_device()
208
+
209
+ if DEVICE<0:
210
+ DEVICE = "cpu"
211
  ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
212
+ ranges = ranges.to(DEVICE)
213
 
214
  pooling_mask = (repeated_idx <= ranges).long()
215