fix bug
Browse files- modeling_yi.py +1 -1
modeling_yi.py
CHANGED
@@ -539,7 +539,7 @@ class YiModel(YiPreTrainedModel):
|
|
539 |
def _prepare_decoder_attention_mask(
|
540 |
self, attention_mask, input_ids, inputs_embeds, past_key_values_length
|
541 |
):
|
542 |
-
input_shape = input_ids.shape if input_ids else inputs_embeds.shape[:-1]
|
543 |
# create causal mask
|
544 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
545 |
combined_attention_mask = None
|
|
|
539 |
def _prepare_decoder_attention_mask(
|
540 |
self, attention_mask, input_ids, inputs_embeds, past_key_values_length
|
541 |
):
|
542 |
+
input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
|
543 |
# create causal mask
|
544 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
545 |
combined_attention_mask = None
|