Spaces:
Sleeping
Sleeping
arthur-qiu
committed on
Commit
·
9330af0
1
Parent(s):
f833804
fix filter
Browse files
- scale_attention.py +7 -10
- scale_attention_turbo.py +7 -10
scale_attention.py
CHANGED
@@ -9,15 +9,15 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
|
|
9 |
x_coord = torch.arange(kernel_size)
|
10 |
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
11 |
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
12 |
-
|
13 |
-
kernel =
|
14 |
|
15 |
return kernel
|
16 |
|
17 |
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
18 |
-
channels = latents.shape[
|
19 |
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
20 |
-
blurred_latents = F.
|
21 |
|
22 |
return blurred_latents
|
23 |
|
@@ -159,7 +159,6 @@ def scale_forward(
|
|
159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
160 |
attn_output = torch.where(count>0, value/count, value)
|
161 |
|
162 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
163 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
164 |
|
165 |
attn_output_global = self.attn1(
|
@@ -168,12 +167,12 @@ def scale_forward(
|
|
168 |
attention_mask=attention_mask,
|
169 |
**cross_attention_kwargs,
|
170 |
)
|
171 |
-
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh
|
172 |
|
173 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
174 |
|
175 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
176 |
-
attn_output = rearrange(attn_output, 'bh
|
177 |
|
178 |
elif fourg_window:
|
179 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
@@ -199,7 +198,6 @@ def scale_forward(
|
|
199 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
200 |
attn_output = torch.where(count>0, value/count, value)
|
201 |
|
202 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
203 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
204 |
|
205 |
value = torch.zeros_like(norm_hidden_states)
|
@@ -221,11 +219,10 @@ def scale_forward(
|
|
221 |
|
222 |
attn_output_global = torch.where(count>0, value/count, value)
|
223 |
|
224 |
-
attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
|
225 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
226 |
|
227 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
228 |
-
attn_output = rearrange(attn_output, 'bh
|
229 |
|
230 |
else:
|
231 |
attn_output = self.attn1(
|
|
|
9 |
x_coord = torch.arange(kernel_size)
|
10 |
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
11 |
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
12 |
+
gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
|
13 |
+
kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
|
14 |
|
15 |
return kernel
|
16 |
|
17 |
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
18 |
+
channels = latents.shape[0]
|
19 |
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
20 |
+
blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
|
21 |
|
22 |
return blurred_latents
|
23 |
|
|
|
159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
160 |
attn_output = torch.where(count>0, value/count, value)
|
161 |
|
|
|
162 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
163 |
|
164 |
attn_output_global = self.attn1(
|
|
|
167 |
attention_mask=attention_mask,
|
168 |
**cross_attention_kwargs,
|
169 |
)
|
170 |
+
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
|
171 |
|
172 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
173 |
|
174 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
175 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
176 |
|
177 |
elif fourg_window:
|
178 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
|
|
198 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
199 |
attn_output = torch.where(count>0, value/count, value)
|
200 |
|
|
|
201 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
202 |
|
203 |
value = torch.zeros_like(norm_hidden_states)
|
|
|
219 |
|
220 |
attn_output_global = torch.where(count>0, value/count, value)
|
221 |
|
|
|
222 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
223 |
|
224 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
225 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
226 |
|
227 |
else:
|
228 |
attn_output = self.attn1(
|
scale_attention_turbo.py
CHANGED
@@ -9,15 +9,15 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
|
|
9 |
x_coord = torch.arange(kernel_size)
|
10 |
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
11 |
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
12 |
-
|
13 |
-
kernel =
|
14 |
|
15 |
return kernel
|
16 |
|
17 |
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
18 |
-
channels = latents.shape[
|
19 |
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
20 |
-
blurred_latents = F.
|
21 |
|
22 |
return blurred_latents
|
23 |
|
@@ -159,7 +159,6 @@ def scale_forward(
|
|
159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
160 |
attn_output = torch.where(count>0, value/count, value)
|
161 |
|
162 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
163 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
164 |
|
165 |
attn_output_global = self.attn1(
|
@@ -168,12 +167,12 @@ def scale_forward(
|
|
168 |
attention_mask=attention_mask,
|
169 |
**cross_attention_kwargs,
|
170 |
)
|
171 |
-
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh
|
172 |
|
173 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
174 |
|
175 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
176 |
-
attn_output = rearrange(attn_output, 'bh
|
177 |
|
178 |
elif fourg_window:
|
179 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
@@ -199,7 +198,6 @@ def scale_forward(
|
|
199 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
200 |
attn_output = torch.where(count>0, value/count, value)
|
201 |
|
202 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
203 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
204 |
|
205 |
value = torch.zeros_like(norm_hidden_states)
|
@@ -221,11 +219,10 @@ def scale_forward(
|
|
221 |
|
222 |
attn_output_global = torch.where(count>0, value/count, value)
|
223 |
|
224 |
-
attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
|
225 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
226 |
|
227 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
228 |
-
attn_output = rearrange(attn_output, 'bh
|
229 |
|
230 |
else:
|
231 |
attn_output = self.attn1(
|
|
|
9 |
x_coord = torch.arange(kernel_size)
|
10 |
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
11 |
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
12 |
+
gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
|
13 |
+
kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
|
14 |
|
15 |
return kernel
|
16 |
|
17 |
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
18 |
+
channels = latents.shape[0]
|
19 |
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
20 |
+
blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
|
21 |
|
22 |
return blurred_latents
|
23 |
|
|
|
159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
160 |
attn_output = torch.where(count>0, value/count, value)
|
161 |
|
|
|
162 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
163 |
|
164 |
attn_output_global = self.attn1(
|
|
|
167 |
attention_mask=attention_mask,
|
168 |
**cross_attention_kwargs,
|
169 |
)
|
170 |
+
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
|
171 |
|
172 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
173 |
|
174 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
175 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
176 |
|
177 |
elif fourg_window:
|
178 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
|
|
198 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
199 |
attn_output = torch.where(count>0, value/count, value)
|
200 |
|
|
|
201 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
202 |
|
203 |
value = torch.zeros_like(norm_hidden_states)
|
|
|
219 |
|
220 |
attn_output_global = torch.where(count>0, value/count, value)
|
221 |
|
|
|
222 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
223 |
|
224 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
225 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
226 |
|
227 |
else:
|
228 |
attn_output = self.attn1(
|