flash_attn package makes it non-portable
#1 opened by bghira
flash_attn only runs on NVIDIA systems, not on Apple or AMD hardware.
Try to avoid requiring it, e.g. load the model without requesting flash attention:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    EMU_HUB,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",  # omit so flash_attn is not required
    trust_remote_code=True,
)
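
A more portable option (just a sketch, not from the model card) is to pick the attention backend at runtime: use flash_attention_2 only when the flash_attn package and a CUDA device are actually available, and fall back to the built-in sdpa implementation otherwise. This assumes the remote modeling code tolerates the sdpa path when flash_attn is not installed; pick_attn_implementation and device_map="auto" are my additions, and EMU_HUB is the repo id from the model card example.

import importlib.util

import torch
from transformers import AutoModelForCausalLM

def pick_attn_implementation() -> str:
    # flash_attention_2 needs the flash_attn package and an NVIDIA GPU.
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    # "sdpa" uses torch.nn.functional.scaled_dot_product_attention and works on CPU/MPS/ROCm.
    return "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    EMU_HUB,  # model repo id, as in the model card example
    device_map="auto",  # let accelerate place weights on whatever device is present
    torch_dtype=torch.bfloat16,
    attn_implementation=pick_attn_implementation(),
    trust_remote_code=True,
)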