jbilcke-hf (HF staff) committed
Commit 1e487c7 · verified · 1 Parent(s): fcbe761

Update enhance.py

Compute head_dim from the q/k/v projection width (inner_dim // num_heads) rather than treating attn.heads as the per-head size, and pass inner_dim and num_heads through to _get_enhance_scores.

Files changed (1)
  1. enhance.py +14 -9
enhance.py CHANGED
@@ -12,8 +12,10 @@ class LTXEnhanceAttnProcessor2_0:
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
             raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")
 
-    def _get_enhance_scores(self, query, key, head_dim, num_frames, text_seq_length=None):
+    def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
         """Calculate enhancement scores for the attention mechanism"""
+        head_dim = inner_dim // num_heads
+
         if text_seq_length is not None:
             img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
             img_k = key[:, :, :-text_seq_length] if text_seq_length > 0 else key
@@ -48,15 +50,17 @@ class LTXEnhanceAttnProcessor2_0:
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
-        head_dim = attn.heads
+        inner_dim = attn.to_q.out_features
+        num_heads = attn.heads
+        head_dim = inner_dim // num_heads
 
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
+        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
         if attn.upcast_attention:
             query = query.float()
@@ -65,8 +69,9 @@ class LTXEnhanceAttnProcessor2_0:
         enhance_scores = None
         if is_enhance_enabled():
            enhance_scores = self._get_enhance_scores(
-                query, key,
-                attn.head_dim,
+                query, key,
+                inner_dim,
+                num_heads,
                 get_num_frames(),
                 text_seq_length
             )
@@ -78,7 +83,7 @@ class LTXEnhanceAttnProcessor2_0:
             is_causal=False
         )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * attn.head_dim)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
         hidden_states = hidden_states.to(query.dtype)
 
         # Apply enhancement if enabled
@@ -112,4 +117,4 @@ def num_frames_hook(module, args, kwargs):
     hidden_states = args[0]
     num_frames = hidden_states.shape[2]
     set_num_frames(num_frames)
-    return args, kwargs
+    return args, kwargs
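
For reference, the core of the change is how the per-head size is derived: the old code passed attn.heads (the number of heads) where a per-head channel size was expected, while the new code computes head_dim = inner_dim // num_heads from the width of the q/k/v projections (attn.to_q.out_features in the diff). A minimal, self-contained sketch of the intended reshape, using made-up sizes rather than the real LTX attention module:

import torch

# Hypothetical sizes for illustration only (not taken from the LTX model config).
batch_size, seq_len, num_heads = 2, 16, 8
inner_dim = 512                    # width of the q/k/v projections (attn.to_q.out_features in the diff)
head_dim = inner_dim // num_heads  # 64 channels per head; the old code used the head count (8) here

query = torch.randn(batch_size, seq_len, inner_dim)

# Split the channel axis into heads, then move heads in front of the sequence axis:
# [batch, seq, inner_dim] -> [batch, num_heads, seq, head_dim]
q = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
assert q.shape == (batch_size, num_heads, seq_len, head_dim)

# After attention, the inverse reshape folds the heads back into one channel axis:
# [batch, num_heads, seq, head_dim] -> [batch, seq, inner_dim]
out = q.transpose(1, 2).reshape(batch_size, -1, inner_dim)
assert out.shape == (batch_size, seq_len, inner_dim)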
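
The last hunk touches num_frames_hook, which records the frame count before the transformer runs. A PyTorch forward pre-hook registered with with_kwargs=True receives (module, args, kwargs) and must return either None or an (args, kwargs) pair, which matches the function shown in the diff. A rough sketch of how such a hook could be wired up; the dummy module, the stand-in helpers, and the assumed input layout [batch, channels, frames, height, width] are illustrative, not taken from this repository:

import torch

_num_frames = None

def set_num_frames(n):  # stand-in for the helper used by enhance.py
    global _num_frames
    _num_frames = n

def num_frames_hook(module, args, kwargs):
    hidden_states = args[0]
    num_frames = hidden_states.shape[2]  # assumes frames live on dim 2
    set_num_frames(num_frames)
    return args, kwargs                  # required return form when with_kwargs=True

class DummyBlock(torch.nn.Module):       # stand-in for the LTX transformer module
    def forward(self, hidden_states):
        return hidden_states

block = DummyBlock()
block.register_forward_pre_hook(num_frames_hook, with_kwargs=True)
block(torch.randn(1, 4, 9, 32, 32))      # side effect: records num_frames == 9
print(_num_frames)                       # -> 9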