arthur-qiu committed on
Commit
9330af0
·
1 Parent(s): f833804

fix filter

Browse files
Files changed (2) hide show
  1. scale_attention.py +7 -10
  2. scale_attention_turbo.py +7 -10
scale_attention.py CHANGED
@@ -9,15 +9,15 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
9
  x_coord = torch.arange(kernel_size)
10
  gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
11
  gaussian_1d = gaussian_1d / gaussian_1d.sum()
12
- gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
13
- kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
14
 
15
  return kernel
16
 
17
  def gaussian_filter(latents, kernel_size=3, sigma=1.0):
18
- channels = latents.shape[1]
19
  kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
20
- blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
21
 
22
  return blurred_latents
23
 
@@ -159,7 +159,6 @@ def scale_forward(
159
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
160
  attn_output = torch.where(count>0, value/count, value)
161
 
162
- attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
163
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
164
 
165
  attn_output_global = self.attn1(
@@ -168,12 +167,12 @@ def scale_forward(
168
  attention_mask=attention_mask,
169
  **cross_attention_kwargs,
170
  )
171
- attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)
172
 
173
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
174
 
175
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
176
- attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
177
 
178
  elif fourg_window:
179
  norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -199,7 +198,6 @@ def scale_forward(
199
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
200
  attn_output = torch.where(count>0, value/count, value)
201
 
202
- attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
203
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
204
 
205
  value = torch.zeros_like(norm_hidden_states)
@@ -221,11 +219,10 @@ def scale_forward(
221
 
222
  attn_output_global = torch.where(count>0, value/count, value)
223
 
224
- attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
225
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
226
 
227
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
228
- attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
229
 
230
  else:
231
  attn_output = self.attn1(
 
9
  x_coord = torch.arange(kernel_size)
10
  gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
11
  gaussian_1d = gaussian_1d / gaussian_1d.sum()
12
+ gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
13
+ kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
14
 
15
  return kernel
16
 
17
  def gaussian_filter(latents, kernel_size=3, sigma=1.0):
18
+ channels = latents.shape[0]
19
  kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
20
+ blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
21
 
22
  return blurred_latents
23
 
 
159
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
160
  attn_output = torch.where(count>0, value/count, value)
161
 
 
162
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
163
 
164
  attn_output_global = self.attn1(
 
167
  attention_mask=attention_mask,
168
  **cross_attention_kwargs,
169
  )
170
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
171
 
172
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
173
 
174
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
175
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
176
 
177
  elif fourg_window:
178
  norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
 
198
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
199
  attn_output = torch.where(count>0, value/count, value)
200
 
 
201
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
202
 
203
  value = torch.zeros_like(norm_hidden_states)
 
219
 
220
  attn_output_global = torch.where(count>0, value/count, value)
221
 
 
222
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
223
 
224
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
225
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
226
 
227
  else:
228
  attn_output = self.attn1(
scale_attention_turbo.py CHANGED
@@ -9,15 +9,15 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
9
  x_coord = torch.arange(kernel_size)
10
  gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
11
  gaussian_1d = gaussian_1d / gaussian_1d.sum()
12
- gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
13
- kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
14
 
15
  return kernel
16
 
17
  def gaussian_filter(latents, kernel_size=3, sigma=1.0):
18
- channels = latents.shape[1]
19
  kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
20
- blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
21
 
22
  return blurred_latents
23
 
@@ -159,7 +159,6 @@ def scale_forward(
159
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
160
  attn_output = torch.where(count>0, value/count, value)
161
 
162
- attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
163
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
164
 
165
  attn_output_global = self.attn1(
@@ -168,12 +167,12 @@ def scale_forward(
168
  attention_mask=attention_mask,
169
  **cross_attention_kwargs,
170
  )
171
- attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)
172
 
173
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
174
 
175
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
176
- attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
177
 
178
  elif fourg_window:
179
  norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -199,7 +198,6 @@ def scale_forward(
199
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
200
  attn_output = torch.where(count>0, value/count, value)
201
 
202
- attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
203
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
204
 
205
  value = torch.zeros_like(norm_hidden_states)
@@ -221,11 +219,10 @@ def scale_forward(
221
 
222
  attn_output_global = torch.where(count>0, value/count, value)
223
 
224
- attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
225
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
226
 
227
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
228
- attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
229
 
230
  else:
231
  attn_output = self.attn1(
 
9
  x_coord = torch.arange(kernel_size)
10
  gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
11
  gaussian_1d = gaussian_1d / gaussian_1d.sum()
12
+ gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
13
+ kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
14
 
15
  return kernel
16
 
17
  def gaussian_filter(latents, kernel_size=3, sigma=1.0):
18
+ channels = latents.shape[0]
19
  kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
20
+ blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
21
 
22
  return blurred_latents
23
 
 
159
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
160
  attn_output = torch.where(count>0, value/count, value)
161
 
 
162
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
163
 
164
  attn_output_global = self.attn1(
 
167
  attention_mask=attention_mask,
168
  **cross_attention_kwargs,
169
  )
170
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
171
 
172
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
173
 
174
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
175
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
176
 
177
  elif fourg_window:
178
  norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
 
198
  count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
199
  attn_output = torch.where(count>0, value/count, value)
200
 
 
201
  gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
202
 
203
  value = torch.zeros_like(norm_hidden_states)
 
219
 
220
  attn_output_global = torch.where(count>0, value/count, value)
221
 
 
222
  gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
223
 
224
  attn_output = gaussian_local + (attn_output_global - gaussian_global)
225
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
226
 
227
  else:
228
  attn_output = self.attn1(