sayakpaul (HF Staff) committed
Commit 39886f2 · verified · 1 parent(s): de4cfc8

Update qwenimage/qwen_fa3_processor.py

Files changed (1)
  1. qwenimage/qwen_fa3_processor.py +4 -3
qwenimage/qwen_fa3_processor.py CHANGED
@@ -1,3 +1,7 @@
+"""
+Paired with a good language model. Thanks!
+"""
+
 import torch
 from typing import Optional, Tuple
 from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
@@ -18,14 +22,11 @@ def _ensure_fa3_available():
         "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
         f"{_kernels_err}"
     )
-
 
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
 def flash_attn_func(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
 ) -> torch.Tensor:
-    from flash_attn.flash_attn_interface import flash_attn_interface_func
-
     outputs, lse = _flash_attn_func(q, k, v, causal=causal)
     return outputs
 
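For context, the op touched by this change follows PyTorch's `torch.library.custom_op` pattern around the FA3 kernel fetched via `get_kernel('kernels-community/vllm-flash-attn3')`. Below is a minimal, self-contained sketch of that pattern, not part of this commit: the real `_flash_attn_func` kernel is swapped for a hypothetical SDPA fallback so the snippet runs anywhere, and a fake (meta) implementation is registered so the op can be traced under torch.compile. The removed per-call import is simply dropped, mirroring the diff.

import torch


def _flash_attn_func(q, k, v, causal=False):
    # Hypothetical stand-in for the FA3 kernel; uses PyTorch SDPA instead.
    # Inputs follow the flash-attn layout (batch, seq, heads, head_dim),
    # while SDPA expects (batch, heads, seq, head_dim), hence the transposes.
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
    ).transpose(1, 2)
    return out, None  # (output, lse) to mirror the FA3 return signature


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
) -> torch.Tensor:
    # No per-call import here: the line removed by this commit imported a
    # symbol that was never used inside the op body.
    outputs, lse = _flash_attn_func(q, k, v, causal=causal)
    return outputs


@flash_attn_func.register_fake
def _(q, k, v, causal=False):
    # Fake/meta implementation: reports the output shape/dtype so the custom
    # op can be traced by torch.compile without running the kernel.
    return torch.empty_like(q)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 16, 8, 64)  # (batch, seq, heads, head_dim)
    print(flash_attn_func(q, k, v).shape)

Only the attention output is exposed by the op (the log-sum-exp value is discarded), matching the file's `return outputs`.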