arthur-qiu committed on
Commit
cca304f
·
1 Parent(s): 2cc3d41

clean code

Browse files
Files changed (2) hide show
  1. pipeline_freescale.py +0 -121
  2. pipeline_freescale_turbo.py +0 -121
pipeline_freescale.py CHANGED
@@ -55,127 +55,6 @@ EXAMPLE_DOC_STRING = """
55
  ```
56
  """
57
 
58
- def default(val, d):
59
- if exists(val):
60
- return val
61
- return d() if isfunction(d) else d
62
-
63
- def exists(val):
64
- return val is not None
65
-
66
- def extract_into_tensor(a, t, x_shape):
67
- b, *_ = t.shape
68
- out = a.gather(-1, t)
69
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
70
-
71
- def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
72
- if schedule == "linear":
73
- betas = (
74
- torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
75
- )
76
- elif schedule == "cosine":
77
- timesteps = (
78
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
79
- )
80
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
81
- alphas = torch.cos(alphas).pow(2)
82
- alphas = alphas / alphas[0]
83
- betas = 1 - alphas[1:] / alphas[:-1]
84
- betas = np.clip(betas, a_min=0, a_max=0.999)
85
- elif schedule == "sqrt_linear":
86
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
87
- elif schedule == "sqrt":
88
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
89
- else:
90
- raise ValueError(f"schedule '{schedule}' unknown.")
91
- return betas.numpy()
92
-
93
- to_torch = partial(torch.tensor, dtype=torch.float16)
94
- betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
95
- alphas = 1. - betas
96
- alphas_cumprod = np.cumprod(alphas, axis=0)
97
- sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
98
- sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))
99
-
100
- def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
101
- noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
102
- return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
103
- extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)
104
-
105
- def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
106
- height //= vae_scale_factor
107
- width //= vae_scale_factor
108
- num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
109
- num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
110
- total_num_blocks = int(num_blocks_height * num_blocks_width)
111
- views = []
112
- for i in range(total_num_blocks):
113
- h_start = int((i // num_blocks_width) * h_window_stride)
114
- h_end = h_start + h_window_size
115
- w_start = int((i % num_blocks_width) * w_window_stride)
116
- w_end = w_start + w_window_size
117
-
118
- if h_end > height:
119
- h_start = int(h_start + height - h_end)
120
- h_end = int(height)
121
- if w_end > width:
122
- w_start = int(w_start + width - w_end)
123
- w_end = int(width)
124
- if h_start < 0:
125
- h_end = int(h_end - h_start)
126
- h_start = 0
127
- if w_start < 0:
128
- w_end = int(w_end - w_start)
129
- w_start = 0
130
-
131
- random_jitter = True
132
- if random_jitter:
133
- h_jitter_range = (h_window_size - h_window_stride) // 4
134
- w_jitter_range = (w_window_size - w_window_stride) // 4
135
- h_jitter = 0
136
- w_jitter = 0
137
-
138
- if (w_start != 0) and (w_end != width):
139
- w_jitter = random.randint(-w_jitter_range, w_jitter_range)
140
- elif (w_start == 0) and (w_end != width):
141
- w_jitter = random.randint(-w_jitter_range, 0)
142
- elif (w_start != 0) and (w_end == width):
143
- w_jitter = random.randint(0, w_jitter_range)
144
- if (h_start != 0) and (h_end != height):
145
- h_jitter = random.randint(-h_jitter_range, h_jitter_range)
146
- elif (h_start == 0) and (h_end != height):
147
- h_jitter = random.randint(-h_jitter_range, 0)
148
- elif (h_start != 0) and (h_end == height):
149
- h_jitter = random.randint(0, h_jitter_range)
150
- h_start += (h_jitter + h_jitter_range)
151
- h_end += (h_jitter + h_jitter_range)
152
- w_start += (w_jitter + w_jitter_range)
153
- w_end += (w_jitter + w_jitter_range)
154
-
155
- views.append((h_start, h_end, w_start, w_end))
156
- return views
157
-
158
- def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
159
- x_coord = torch.arange(kernel_size)
160
- gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
161
- gaussian_1d = gaussian_1d / gaussian_1d.sum()
162
- gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
163
- kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
164
-
165
- return kernel
166
-
167
- def gaussian_filter(latents, kernel_size=3, sigma=1.0):
168
- channels = latents.shape[1]
169
- kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
170
- if len(latents.shape) == 5:
171
- b = latents.shape[0]
172
- latents = rearrange(latents, 'b c t i j -> (b t) c i j')
173
- blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
174
- blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
175
- else:
176
- blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
177
-
178
- return blurred_latents
179
 
180
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
181
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 
55
  ```
56
  """
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
60
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
pipeline_freescale_turbo.py CHANGED
@@ -55,127 +55,6 @@ EXAMPLE_DOC_STRING = """
55
  ```
56
  """
57
 
58
- def default(val, d):
59
- if exists(val):
60
- return val
61
- return d() if isfunction(d) else d
62
-
63
- def exists(val):
64
- return val is not None
65
-
66
- def extract_into_tensor(a, t, x_shape):
67
- b, *_ = t.shape
68
- out = a.gather(-1, t)
69
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
70
-
71
- def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
72
- if schedule == "linear":
73
- betas = (
74
- torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
75
- )
76
- elif schedule == "cosine":
77
- timesteps = (
78
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
79
- )
80
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
81
- alphas = torch.cos(alphas).pow(2)
82
- alphas = alphas / alphas[0]
83
- betas = 1 - alphas[1:] / alphas[:-1]
84
- betas = np.clip(betas, a_min=0, a_max=0.999)
85
- elif schedule == "sqrt_linear":
86
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
87
- elif schedule == "sqrt":
88
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
89
- else:
90
- raise ValueError(f"schedule '{schedule}' unknown.")
91
- return betas.numpy()
92
-
93
- to_torch = partial(torch.tensor, dtype=torch.float16)
94
- betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
95
- alphas = 1. - betas
96
- alphas_cumprod = np.cumprod(alphas, axis=0)
97
- sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
98
- sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))
99
-
100
- def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
101
- noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
102
- return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
103
- extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)
104
-
105
- def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
106
- height //= vae_scale_factor
107
- width //= vae_scale_factor
108
- num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
109
- num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
110
- total_num_blocks = int(num_blocks_height * num_blocks_width)
111
- views = []
112
- for i in range(total_num_blocks):
113
- h_start = int((i // num_blocks_width) * h_window_stride)
114
- h_end = h_start + h_window_size
115
- w_start = int((i % num_blocks_width) * w_window_stride)
116
- w_end = w_start + w_window_size
117
-
118
- if h_end > height:
119
- h_start = int(h_start + height - h_end)
120
- h_end = int(height)
121
- if w_end > width:
122
- w_start = int(w_start + width - w_end)
123
- w_end = int(width)
124
- if h_start < 0:
125
- h_end = int(h_end - h_start)
126
- h_start = 0
127
- if w_start < 0:
128
- w_end = int(w_end - w_start)
129
- w_start = 0
130
-
131
- random_jitter = True
132
- if random_jitter:
133
- h_jitter_range = (h_window_size - h_window_stride) // 4
134
- w_jitter_range = (w_window_size - w_window_stride) // 4
135
- h_jitter = 0
136
- w_jitter = 0
137
-
138
- if (w_start != 0) and (w_end != width):
139
- w_jitter = random.randint(-w_jitter_range, w_jitter_range)
140
- elif (w_start == 0) and (w_end != width):
141
- w_jitter = random.randint(-w_jitter_range, 0)
142
- elif (w_start != 0) and (w_end == width):
143
- w_jitter = random.randint(0, w_jitter_range)
144
- if (h_start != 0) and (h_end != height):
145
- h_jitter = random.randint(-h_jitter_range, h_jitter_range)
146
- elif (h_start == 0) and (h_end != height):
147
- h_jitter = random.randint(-h_jitter_range, 0)
148
- elif (h_start != 0) and (h_end == height):
149
- h_jitter = random.randint(0, h_jitter_range)
150
- h_start += (h_jitter + h_jitter_range)
151
- h_end += (h_jitter + h_jitter_range)
152
- w_start += (w_jitter + w_jitter_range)
153
- w_end += (w_jitter + w_jitter_range)
154
-
155
- views.append((h_start, h_end, w_start, w_end))
156
- return views
157
-
158
- def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
159
- x_coord = torch.arange(kernel_size)
160
- gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
161
- gaussian_1d = gaussian_1d / gaussian_1d.sum()
162
- gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
163
- kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
164
-
165
- return kernel
166
-
167
- def gaussian_filter(latents, kernel_size=3, sigma=1.0):
168
- channels = latents.shape[1]
169
- kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
170
- if len(latents.shape) == 5:
171
- b = latents.shape[0]
172
- latents = rearrange(latents, 'b c t i j -> (b t) c i j')
173
- blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
174
- blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
175
- else:
176
- blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
177
-
178
- return blurred_latents
179
 
180
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
181
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 
55
  ```
56
  """
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
60
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):