x54-729 commited on
Commit
daa886c
1 Parent(s): 7bee5c5

update modeling file to newest

Browse files
Files changed (1) hide show
  1. modeling_internlm2.py +11 -3
modeling_internlm2.py CHANGED
@@ -13,7 +13,7 @@
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
- """PyTorch InternLM2.5 model."""
17
  import math
18
  import queue
19
  import threading
@@ -59,6 +59,10 @@ try:
59
  except:
60
  pass
61
 
 
 
 
 
62
 
63
  logger = logging.get_logger(__name__)
64
 
@@ -1093,7 +1097,11 @@ class InternLM2Model(InternLM2PreTrainedModel):
1093
  else:
1094
  causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1095
  if sequence_length != 1:
1096
- causal_mask = torch.triu(causal_mask, diagonal=1)
 
 
 
 
1097
  causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1098
  causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1099
  if attention_mask is not None:
@@ -1797,4 +1805,4 @@ class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
1797
  logits=logits,
1798
  hidden_states=outputs.hidden_states,
1799
  attentions=outputs.attentions,
1800
- )
 
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
+ """PyTorch InternLM2 model."""
17
  import math
18
  import queue
19
  import threading
 
59
  except:
60
  pass
61
 
62
+ try:
63
+ support_bf16_triu = torch.__version__ >= "2.1.0"
64
+ except Exception:
65
+ support_bf16_triu = False
66
 
67
  logger = logging.get_logger(__name__)
68
 
 
1097
  else:
1098
  causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1099
  if sequence_length != 1:
1100
+ if support_bf16_triu or dtype == torch.float32:
1101
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1102
+ else:
1103
+ triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
1104
+ causal_mask.masked_fill_(~triu_mask, 0)
1105
  causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1106
  causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1107
  if attention_mask is not None:
 
1805
  logits=logits,
1806
  hidden_states=outputs.hidden_states,
1807
  attentions=outputs.attentions,
1808
+ )