FlashAttention2 support not working during training

#1
by Qubitium - opened

@v2ray First of all, huge thanks for the first working model that I can actually full-finetune in bfloat16. However, the VRAM usage is insane due to a FlashAttention 2 incompatibility. Are you able to train (bfloat16) with FlashAttention 2 enabled on your end?

@Qubitium Yes, I am able to use FlashAttention 2. I'm not doing a full fine-tune though; what I tested was a LoRA tune.
https://github.com/LagPixelLOL/qlora/blob/main/scripts/finetune_schizogpt_132b.sh
This is the script I used to test, with eval disabled for DeepSpeed to work.
https://huggingface.co/v2ray/SchizoGPT-132B-QLoRA
This is the result of the training run.
Despite the name, it's actually just a regular LoRA rather than a QLoRA, because I set the bits to 16. I trained it on 8x A100 80GB.
All the libraries I used are at their latest release versions (not the dev versions), and the CUDA version I used is 12.2.
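For readers hitting the same VRAM problem: in recent versions of Hugging Face Transformers (4.36 and later, as far as I know), FlashAttention 2 is requested at load time through the `attn_implementation` argument of `from_pretrained`, and it only runs in float16 or bfloat16. The sketch below is a hypothetical helper (`fa2_load_kwargs` is not part of v2ray's script) that builds those keyword arguments and rejects other dtypes; it assumes the `flash-attn` package is installed separately.

```python
# Hypothetical helper, not from the linked script: builds keyword arguments for
# transformers' AutoModelForCausalLM.from_pretrained that request FlashAttention 2.
# Assumes a transformers version that accepts attn_implementation (4.36+).

# FlashAttention 2 kernels only support half-precision dtypes.
FA2_DTYPES = {"float16", "bfloat16"}


def fa2_load_kwargs(dtype: str = "bfloat16") -> dict:
    """Return from_pretrained kwargs that enable FlashAttention 2."""
    if dtype not in FA2_DTYPES:
        raise ValueError(f"FlashAttention 2 requires float16 or bfloat16, got {dtype!r}")
    return {
        # Passed as a string here; use getattr(torch, dtype) instead if your
        # transformers version only accepts a torch.dtype object.
        "torch_dtype": dtype,
        "attn_implementation": "flash_attention_2",
    }


if __name__ == "__main__":
    print(fa2_load_kwargs())
```

Usage would then look like `AutoModelForCausalLM.from_pretrained(model_id, **fa2_load_kwargs())`, where `model_id` is whichever checkpoint you are training. Failing loudly on float32 is deliberate: silently falling back to another attention implementation is exactly the kind of thing that shows up later as unexplained VRAM usage.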

Hey v2ray,
thank you for the conversion.
I'm using TRL for fine-tuning and I'm stuck on the target_modules for PEFT. The repo you linked has a function that extracts all linear layers, but I get an error when I run it.
Which modules did you use?
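The helper from that repo isn't reproduced here. As a rough sketch of the idea, assuming you already have the fully qualified names of the `nn.Linear` modules (for example from `model.named_modules()`), PEFT's `target_modules` only needs the last path component of each. The function name `leaf_linear_names` and the module names in the test are hypothetical, chosen so the result matches the list the owner posts further down in this thread.

```python
from typing import Iterable, List


def leaf_linear_names(
    qualified_names: Iterable[str], exclude: Iterable[str] = ("lm_head",)
) -> List[str]:
    """Reduce dotted module paths to unique leaf names, minus excluded ones."""
    # "blocks.0.attn.Wqkv" -> "Wqkv"; a top-level name like "lm_head" stays as-is.
    leaves = {name.rsplit(".", 1)[-1] for name in qualified_names if name}
    return sorted(leaves - set(exclude))


if __name__ == "__main__":
    # Hypothetical names. With torch you would collect the real ones via:
    # [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    example = [
        "transformer.blocks.0.norm_attn_norm.attn.Wqkv",
        "transformer.blocks.0.ffn.router.layer",
        "lm_head",
    ]
    print(leaf_linear_names(example))
```

Excluding `lm_head` follows the common QLoRA-script convention of leaving the output projection untouched; adjust `exclude` if you want it trained.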

@v2ray The FA2 bug I encountered was caused by my own custom training code. Sorry for the false alarm. =P

Qubitium changed discussion status to closed
Owner

@ChristianPalaArtificialy Hello, I'm using:

  "target_modules": [
    "v1",
    "Wqkv",
    "layer",
    "out_proj",
    "w1",
    "w2"
  ],
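To check which modules such a list actually selects: when `target_modules` is a list, PEFT treats an entry as matching a module whose full name either equals it or ends with `.` followed by that entry (a single string is instead treated as a regex). The function below is a small stand-alone approximation of that rule, not PEFT's own code, and the module names it is tested on are hypothetical.

```python
from typing import Sequence

# The list from the reply above.
TARGET_MODULES = ["v1", "Wqkv", "layer", "out_proj", "w1", "w2"]


def is_lora_target(module_name: str, targets: Sequence[str] = TARGET_MODULES) -> bool:
    """Approximate PEFT's list matching: exact name or a whole-component suffix."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


if __name__ == "__main__":
    for name in [
        "transformer.blocks.0.norm_attn_norm.attn.Wqkv",
        "transformer.blocks.0.ffn.router.layer",
        "lm_head",
    ]:
        print(name, is_lora_target(name))
```

Because the match is on the last path component, a short entry such as `layer` selects every module with that leaf name anywhere in the model, so it is worth printing the matches once before launching a long run.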

Also what's the error you were getting?
