request: Add flash attention 2.0 support for GPT2LMHeadModel
#75 opened by brresnic
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    my_GPT2LMHeadModel_checkpoint,  # path to a GPT-2 checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
Running this throws the following error:
Error loading Flash_Model_2: GPT2LMHeadModel does not support Flash Attention 2.0 yet. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new
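As a workaround until support lands, the failed load can be caught and retried with the default attention implementation. This is only a minimal sketch, assuming "gpt2" as a stand-in for the actual checkpoint path and that transformers raises a ValueError here (which the message above suggests):

import torch
from transformers import AutoModelForCausalLM

my_GPT2LMHeadModel_checkpoint = "gpt2"  # example value; substitute your own GPT-2 checkpoint

try:
    # transformers raises a ValueError when the architecture does not support FA2
    model = AutoModelForCausalLM.from_pretrained(
        my_GPT2LMHeadModel_checkpoint,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
except ValueError:
    # fall back to the default (eager) attention implementation
    model = AutoModelForCausalLM.from_pretrained(
        my_GPT2LMHeadModel_checkpoint,
        torch_dtype=torch.bfloat16,
    )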
Hi @brresnic,
Thanks for your interest! There is an ongoing effort to add FA2 to GPT2 here: https://github.com/huggingface/transformers/pull/27479
Note, however, that since the model is relatively small, I don't expect very significant speedups from FA2 with gpt2.
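For what it's worth, once the linked PR is merged, a rough way to check whether FA2 helps on a given setup is to time generation under both implementations. This is only a sketch, assuming a CUDA device, an installed flash-attn package, and the stock "gpt2" checkpoint:

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def time_generation(attn_impl):
    # load gpt2 with the requested attention implementation and time one greedy generation
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    ).to("cuda")
    inputs = tok("Hello, my name is", return_tensors="pt").to("cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("eager:", time_generation("eager"))
print("flash_attention_2:", time_generation("flash_attention_2"))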