vaibhavad committed
Commit b941c5a
Parent: 70be26b

Update attn_mask_utils.py

Files changed (1):
  attn_mask_utils.py +3 -2
attn_mask_utils.py CHANGED
@@ -38,8 +38,9 @@ def _prepare_4d_attention_mask_for_sdpa(
     elif query_length == 1:
         # For query_length == 1, causal attention and bi-directional attention are the same.
         attention_mask = None
-    elif key_value_length == query_length:
-        attention_mask = None
+    # Commented out to deal with batch size=1 cases
+    # elif key_value_length == query_length:
+    #     attention_mask = None
     else:
         # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
         # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
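
For context, here is a minimal sketch of the branch logic this hunk produces. The function name prepare_4d_mask_sketch, the tensor shapes, and the 4D expansion step at the end are illustrative assumptions rather than the file's actual implementation; only the branch structure mirrors the diff above.

# Minimal sketch of the mask-preparation control flow after this commit.
# Names, shapes, and the expansion step are assumptions for illustration.
from typing import Optional

import torch


def prepare_4d_mask_sketch(
    attention_mask: Optional[torch.Tensor],  # assumed shape: (batch_size, key_value_length), 1 = keep, 0 = pad
    query_length: int,
    key_value_length: int,
) -> Optional[torch.Tensor]:
    if attention_mask is None:
        return None
    if query_length == 1:
        # For a single query token, causal and bi-directional attention
        # coincide, so SDPA can run without an explicit mask.
        return None
    # The `key_value_length == query_length` shortcut removed by this commit
    # would have returned None here; after the change we always fall through
    # and expand the 2D padding mask to 4D for SDPA.
    batch_size = attention_mask.shape[0]
    expanded = attention_mask[:, None, None, :].expand(
        batch_size, 1, query_length, key_value_length
    )
    # Convert keep/pad flags into an additive bias: 0 where kept, a large
    # negative value where padded.
    inverted = 1.0 - expanded.to(torch.float32)
    return inverted.masked_fill(
        inverted.to(torch.bool), torch.finfo(torch.float32).min
    )

The point of the commit is visible in the fall-through: even when query_length equals key_value_length, a caller-supplied padding mask is now expanded rather than silently dropped, which per the added comment matters for batch size = 1 cases.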