ltg
/

davda54 commited on
Commit
ca3c03c
1 Parent(s): 319aff6

Update modeling_deberta.py

Browse files
Files changed (1) hide show
  1. modeling_deberta.py +4 -10
modeling_deberta.py CHANGED
@@ -14,6 +14,7 @@
14
  # limitations under the License.
15
  """ PyTorch DeBERTa-v2 model."""
16
 
 
17
  from collections.abc import Sequence
18
  from typing import Optional, Tuple, Union
19
 
@@ -553,16 +554,9 @@ class DebertaV2Encoder(nn.Module):
553
  def make_log_bucket_position(relative_pos, bucket_size, max_position):
554
  sign = torch.sign(relative_pos)
555
  mid = bucket_size // 2
556
- abs_pos = torch.where(
557
- (relative_pos < mid) & (relative_pos > -mid),
558
- torch.tensor(mid - 1).type_as(relative_pos),
559
- torch.abs(relative_pos),
560
- )
561
- log_pos = (
562
- torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
563
- )
564
- bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
565
- bucket_pos = bucket_pos.clamp(min=-bucket_size+1, max=bucket_size-1)
566
  return bucket_pos
567
 
568
 
 
14
  # limitations under the License.
15
  """ PyTorch DeBERTa-v2 model."""
16
 
17
+ import math
18
  from collections.abc import Sequence
19
  from typing import Optional, Tuple, Union
20
 
 
554
  def make_log_bucket_position(relative_pos, bucket_size, max_position):
555
  sign = torch.sign(relative_pos)
556
  mid = bucket_size // 2
557
+ abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
558
+ log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
559
+ bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
 
 
 
 
 
 
 
560
  return bucket_pos
561
 
562