Do you really use flash attention?
#5 by GinnM - opened
I noticed that:
attn = scaled_dot_product_attention(
    query=xq.transpose(1, 2),
    key=xk.transpose(1, 2),
    value=xv.transpose(1, 2),
    attn_mask=attention_mask.bool(),
    dropout_p=0,
).transpose(1, 2)
But when the attn_mask parameter is not None, scaled_dot_product_attention will not actually dispatch to the FlashAttention backend: the flash kernel does not support an explicit attention mask, so the call falls back to the memory-efficient or math implementation.
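A quick way to check this is to restrict SDPA to the flash backend and see which calls succeed. This is only a minimal sketch, assuming PyTorch >= 2.0 on a CUDA GPU with fp16 inputs; the tensor shapes and the `mask` tensor below are made up for illustration and are not taken from the model code:

```python
import torch
import torch.nn.functional as F

B, H, S, D = 2, 8, 128, 64  # batch, heads, sequence length, head dim (illustrative)
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
mask = torch.ones(B, 1, S, S, device="cuda", dtype=torch.bool)

# Restrict dispatch to the FlashAttention backend only.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # Works: causal masking is expressed via is_causal, not an explicit mask.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Fails: the flash backend rejects an explicit attn_mask, and with the
    # math/mem-efficient backends disabled there is no kernel left to run.
    try:
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as e:
        print("flash backend rejected attn_mask:", e)
```

In normal (unrestricted) dispatch the second call does not error; it silently runs on the memory-efficient or math backend instead of FlashAttention.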