sayakpaul HF Staff committed on
Commit
4a94046
·
verified ·
1 Parent(s): 167468e

Create qwen_fa3_processor.py

Browse files
Files changed (1) hide show
  1. qwenimage/qwen_fa3_processor.py +126 -0
qwenimage/qwen_fa3_processor.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Optional, Tuple

import torch
2
+
3
+ @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
4
+ def flash_attn_func(
5
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
6
+ ) -> torch.Tensor:
7
+ from flash_attn.flash_attn_interface import flash_attn_interface_func
8
+
9
+ outputs, lse = _flash_attn_func(q, k, v, causal=causal)
10
+ return outputs
11
+
12
@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    """Fake kernel for ``flash_attn_func``: describe the output without computing it.

    The op exposes a single tensor, the attention output, laid out as
    (batch, seq_len, num_heads, head_dim). It therefore has the same shape,
    dtype and device as ``q``; ``k``, ``v`` and any keyword arguments do not
    affect the result.
    """
    placeholder = torch.empty_like(q)
    # Return a contiguous tensor regardless of how `q` itself is laid out.
    return placeholder.contiguous()
19
+
20
+
21
class QwenDoubleStreamAttnProcessorFA3:
    """
    FA3-based attention processor for Qwen double-stream architecture.
    Computes joint attention over concatenated [text, image] streams by calling the
    module-level `flash_attn_func` op, which wraps `flash_attn.flash_attn_interface`.

    Notes / limitations:
      - General attention masks are not supported here (FA3 path). `is_causal=False` and no arbitrary mask.
      - `encoder_hidden_states_mask` is accepted for interface parity but is not used.
      - `__call__` runs under `torch.no_grad()`, so no gradients flow through this processor.
      - `apply_rotary_emb_qwen` is neither defined nor imported in this file; it must be available
        in the module namespace (same as the non-FA3 processor) whenever `image_rotary_emb` is passed.
    """

    _attention_backend = "fa3"  # for parity with other processors, not used internally

    def __init__(self):
        """Check that the FlashAttention-3 interface can be imported.

        Raises:
            ImportError: If `flash_attn.flash_attn_interface` cannot be imported. The
                underlying import failure is chained as the cause.
        """
        try:
            from flash_attn.flash_attn_interface import flash_attn_interface_func  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "flash_attention v3 package is required to be installed"
            ) from err

    @torch.no_grad()
    def __call__(
        self,
        attn,  # Attention module with to_q/to_k/to_v/add_*_proj, norms, to_out, to_add_out, and .heads
        hidden_states: torch.FloatTensor,  # (B, S_img, D_model) image stream
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # (B, S_txt, D_model) text stream
        encoder_hidden_states_mask: Optional[torch.FloatTensor] = None,  # unused in FA3 path
        attention_mask: Optional[torch.FloatTensor] = None,  # unused in FA3 path
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (img_freqs, txt_freqs)
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Apply joint text+image attention and return the per-stream outputs.

        Args:
            attn: Attention module providing the projections, optional Q/K norms and `.heads`.
            hidden_states: Image stream, (B, S_img, D_model).
            encoder_hidden_states: Text stream, (B, S_txt, D_model). Required.
            encoder_hidden_states_mask: Ignored.
            attention_mask: Must be `None`; arbitrary masks are not supported.
            image_rotary_emb: Optional `(img_freqs, txt_freqs)` pair applied to Q and K.

        Returns:
            `(img_attn_out, txt_attn_out)`: the image and text outputs after their
            respective output projections.

        Raises:
            ValueError: If `encoder_hidden_states` is `None`.
            NotImplementedError: If `attention_mask` is not `None`.
        """
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
        if attention_mask is not None:
            # FA3 kernel path here does not consume arbitrary masks; fail fast to avoid silent correctness issues.
            raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")

        # FA3 availability is verified once in __init__, so there is no per-call check here
        # (the previous call to `_ensure_fa3_available()` referenced an undefined name).

        S_txt = encoder_hidden_states.shape[1]
        H = attn.heads

        # ---- QKV projections (image/sample stream), reshaped to (B, S_img, H, D_h) ----
        img_q = attn.to_q(hidden_states).unflatten(-1, (H, -1))
        img_k = attn.to_k(hidden_states).unflatten(-1, (H, -1))
        img_v = attn.to_v(hidden_states).unflatten(-1, (H, -1))

        # ---- QKV projections (text/context stream), reshaped to (B, S_txt, H, D_h) ----
        txt_q = attn.add_q_proj(encoder_hidden_states).unflatten(-1, (H, -1))
        txt_k = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (H, -1))
        txt_v = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (H, -1))

        # ---- Q/K normalization (each norm is optional on the module) ----
        if getattr(attn, "norm_q", None) is not None:
            img_q = attn.norm_q(img_q)
        if getattr(attn, "norm_k", None) is not None:
            img_k = attn.norm_k(img_k)
        if getattr(attn, "norm_added_q", None) is not None:
            txt_q = attn.norm_added_q(txt_q)
        if getattr(attn, "norm_added_k", None) is not None:
            txt_k = attn.norm_added_k(txt_k)

        # ---- RoPE (Qwen variant) ----
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            # expects tensors shaped (B, S, H, D_h)
            img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)

        # ---- Joint attention over [text, image] along sequence axis ----
        # Shapes: (B, S_total, H, D_h); text comes first so it can be sliced off with S_txt below.
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)

        # The custom op takes (B, S, H, D_h) and returns only the attention output.
        out = flash_attn_func(q, k, v, causal=False)  # out: (B, S_total, H, D_h)

        # ---- Back to (B, S, D_model) ----
        out = out.flatten(2, 3).to(q.dtype)

        # Split back to text / image segments
        txt_attn_out = out[:, :S_txt, :]
        img_attn_out = out[:, S_txt:, :]

        # ---- Output projections ----
        img_attn_out = attn.to_out[0](img_attn_out)
        if len(attn.to_out) > 1:
            img_attn_out = attn.to_out[1](img_attn_out)  # dropout if present

        txt_attn_out = attn.to_add_out(txt_attn_out)

        return img_attn_out, txt_attn_out