Spaces:
Running
on
Zero
Running
on
Zero
Update qwenimage/qwen_fa3_processor.py
Browse files
qwenimage/qwen_fa3_processor.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from typing import Optional, Tuple
|
3 |
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
|
@@ -18,14 +22,11 @@ def _ensure_fa3_available():
|
|
18 |
"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
|
19 |
f"{_kernels_err}"
|
20 |
)
|
21 |
-
|
22 |
|
23 |
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
24 |
def flash_attn_func(
|
25 |
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
|
26 |
) -> torch.Tensor:
|
27 |
-
from flash_attn.flash_attn_interface import flash_attn_interface_func
|
28 |
-
|
29 |
outputs, lse = _flash_attn_func(q, k, v, causal=causal)
|
30 |
return outputs
|
31 |
|
|
|
1 |
+
"""
|
2 |
+
Paired with a good language model. Thanks!
|
3 |
+
"""
|
4 |
+
|
5 |
import torch
|
6 |
from typing import Optional, Tuple
|
7 |
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
|
|
|
22 |
"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
|
23 |
f"{_kernels_err}"
|
24 |
)
|
|
|
25 |
|
26 |
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
27 |
def flash_attn_func(
|
28 |
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
|
29 |
) -> torch.Tensor:
|
|
|
|
|
30 |
outputs, lse = _flash_attn_func(q, k, v, causal=causal)
|
31 |
return outputs
|
32 |
|