Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 commited on
Commit
a65d51f
·
verified ·
1 Parent(s): 739f9ab

Update modularStarEncoder.py

Browse files
Files changed (1) hide show
  1. modularStarEncoder.py +3 -1
modularStarEncoder.py CHANGED
@@ -205,11 +205,13 @@ def get_pooling_mask(
205
  repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
206
 
207
  DEVICE = input_ids.get_device()
 
 
208
  if DEVICE<0:
209
  DEVICE = "cpu"
210
  ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
211
 
212
-
213
  pooling_mask = (repeated_idx <= ranges).long()
214
  pooling_mask.to(DEVICE)
215
 
 
205
  repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
206
 
207
  DEVICE = input_ids.get_device()
208
+ print(DEVICE)
209
+
210
  if DEVICE<0:
211
  DEVICE = "cpu"
212
  ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
213
 
214
+ print(repeated_idx.get_device(),ranges.get_device())
215
  pooling_mask = (repeated_idx <= ranges).long()
216
  pooling_mask.to(DEVICE)
217