Update enhance.py
enhance.py  +14 -9  CHANGED
@@ -12,8 +12,10 @@ class LTXEnhanceAttnProcessor2_0:
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
             raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")

-    def _get_enhance_scores(self, query, key,
+    def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
         """Calculate enhancement scores for the attention mechanism"""
+        head_dim = inner_dim // num_heads
+
         if text_seq_length is not None:
             img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
             img_k = key[:, :, :-text_seq_length] if text_seq_length > 0 else key
@@ -48,15 +50,17 @@ class LTXEnhanceAttnProcessor2_0:
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states

-
+        inner_dim = attn.to_q.out_features
+        num_heads = attn.heads
+        head_dim = inner_dim // num_heads

         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

-        query = query.view(batch_size, -1,
-        key = key.view(batch_size, -1,
-        value = value.view(batch_size, -1,
+        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

         if attn.upcast_attention:
             query = query.float()
@@ -65,8 +69,9 @@ class LTXEnhanceAttnProcessor2_0:
         enhance_scores = None
         if is_enhance_enabled():
             enhance_scores = self._get_enhance_scores(
-                query, key,
-
+                query, key,
+                inner_dim,
+                num_heads,
                 get_num_frames(),
                 text_seq_length
             )
@@ -78,7 +83,7 @@ class LTXEnhanceAttnProcessor2_0:
             is_causal=False
         )

-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1,
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
         hidden_states = hidden_states.to(query.dtype)

         # Apply enhancement if enabled
@@ -112,4 +117,4 @@ def num_frames_hook(module, args, kwargs):
     hidden_states = args[0]
     num_frames = hidden_states.shape[2]
     set_num_frames(num_frames)
-    return args, kwargs
+    return args, kwargs
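The q/k/v reshape added in this commit is the standard multi-head attention layout. A minimal shape sketch (sizes are illustrative, not taken from the model; inner_dim stands in for attn.to_q.out_features):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 64, 8, 32
inner_dim = num_heads * head_dim  # mirrors attn.to_q.out_features

query = torch.randn(batch_size, seq_len, inner_dim)

# view + transpose: (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
assert query.shape == (batch_size, num_heads, seq_len, head_dim)

# The inverse, applied after attention, restores (batch, seq, inner_dim):
out = query.transpose(1, 2).reshape(batch_size, -1, inner_dim)
assert out.shape == (batch_size, seq_len, inner_dim)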
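The negative slice in _get_enhance_scores suggests the sequence axis holds image tokens first with text tokens appended. A small sketch with hypothetical token counts, using the same (batch, heads, seq, head_dim) layout the processor produces after the transpose:

import torch

batch, heads, img_len, txt_len, head_dim = 1, 8, 128, 16, 32
query = torch.randn(batch, heads, img_len + txt_len, head_dim)

text_seq_length = txt_len
# Same guard as the diff: only slice when there are text tokens to drop.
img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
assert img_q.shape == (batch, heads, img_len, head_dim)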
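num_frames_hook matches the signature PyTorch 2.x expects from a kwargs-aware forward pre-hook, and such a hook must return None or (args, kwargs), which is what the restored return statement provides. A hypothetical registration sketch (the Identity module is a stand-in for whichever block receives the 5-D video latents; num_frames_hook and set_num_frames come from this file):

import torch

video_block = torch.nn.Identity()  # stand-in module for illustration
handle = video_block.register_forward_pre_hook(num_frames_hook, with_kwargs=True)

latents = torch.randn(1, 128, 9, 32, 32)  # frame count lives on dim 2
video_block(latents)                      # hook records num_frames = 9
handle.remove()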