Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 committed on
Commit 75b2f89 · verified · 1 Parent(s): a65d51f

Update modularStarEncoder.py

Files changed (1)
  1. modularStarEncoder.py +3 -3
modularStarEncoder.py CHANGED
@@ -205,15 +205,15 @@ def get_pooling_mask(
     repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
 
     DEVICE = input_ids.get_device()
-    print(DEVICE)
 
     if DEVICE<0:
         DEVICE = "cpu"
     ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
+    ranges.to(DEVICE)
+    repeated_idx.to(DEVICE)
 
-    print(repeated_idx.get_device(),ranges.get_device())
     pooling_mask = (repeated_idx <= ranges).long()
-    pooling_mask.to(DEVICE)
+
 
     return pooling_mask
 
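
The hunk drops the debug print statements and attempts to move ranges and repeated_idx onto the same device as input_ids before the comparison. Below is a minimal, self-contained sketch of the same pooling-mask logic; it is not the full function from modularStarEncoder.py (the signature is truncated in the hunk header, so the parameter names and shapes here are assumptions read off the diff). Because torch.Tensor.to() returns a new tensor rather than moving it in place, the sketch reassigns the result and creates ranges directly on the target device.

import torch

def get_pooling_mask(input_ids: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Sketch only: assumed shapes are input_ids (batch, seq_len) and
    # idx (batch,), one index per sequence.
    device = input_ids.device  # covers CPU and GPU without the get_device()/-1 check

    # Repeat each sequence's index across the sequence dimension: (batch, seq_len).
    # .to() returns a new tensor, so the result must be reassigned.
    repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1)).to(device)

    # Position indices 0..seq_len-1 for every sequence, built on the same device.
    ranges = torch.arange(input_ids.size(1), device=device).repeat(input_ids.size(0), 1)

    # 1 at positions whose index is >= idx, 0 before it.
    pooling_mask = (repeated_idx <= ranges).long()
    return pooling_mask

Constructing ranges with device=device keeps both operands of the comparison on one device, which is what the added .to(DEVICE) calls in the commit are aiming for.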