Update modeling_bert.py
#1
by
bwang0911
- opened
- modeling_bert.py +1 -22
modeling_bert.py
CHANGED
@@ -675,11 +675,6 @@ class JinaBertEncoder(nn.Module):
|
|
675 |
)
|
676 |
self.gradient_checkpointing = False
|
677 |
self.num_attention_heads = config.num_attention_heads
|
678 |
-
self.register_buffer(
|
679 |
-
"alibi",
|
680 |
-
self.rebuild_alibi_tensor(size=config.max_position_embeddings),
|
681 |
-
persistent=False,
|
682 |
-
)
|
683 |
|
684 |
def rebuild_alibi_tensor(
|
685 |
self, size: int, device: Optional[Union[torch.device, str]] = None
|
@@ -747,23 +742,7 @@ class JinaBertEncoder(nn.Module):
|
|
747 |
|
748 |
# Add alibi matrix to extended_attention_mask
|
749 |
_, seqlen, _ = hidden_states.size()
|
750 |
-
|
751 |
-
# Rebuild the alibi tensor when needed
|
752 |
-
warnings.warn(
|
753 |
-
f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
|
754 |
-
)
|
755 |
-
self.register_buffer(
|
756 |
-
"alibi",
|
757 |
-
self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
|
758 |
-
hidden_states.dtype
|
759 |
-
),
|
760 |
-
persistent=False,
|
761 |
-
)
|
762 |
-
elif self.alibi.device != hidden_states.device:
|
763 |
-
# Device catch-up
|
764 |
-
self.alibi = self.alibi.to(hidden_states.device)
|
765 |
-
|
766 |
-
alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
|
767 |
if self.gradient_checkpointing and self.training:
|
768 |
if use_cache:
|
769 |
logger.warning_once(
|
|
|
675 |
)
|
676 |
self.gradient_checkpointing = False
|
677 |
self.num_attention_heads = config.num_attention_heads
|
|
|
|
|
|
|
|
|
|
|
678 |
|
679 |
def rebuild_alibi_tensor(
|
680 |
self, size: int, device: Optional[Union[torch.device, str]] = None
|
|
|
742 |
|
743 |
# Add alibi matrix to extended_attention_mask
|
744 |
_, seqlen, _ = hidden_states.size()
|
745 |
+
alibi_bias = self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(hidden_states.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
746 |
if self.gradient_checkpointing and self.training:
|
747 |
if use_cache:
|
748 |
logger.warning_once(
|