Spaces:
Running
Running
ddpm
Browse files
app.py
CHANGED
@@ -188,11 +188,45 @@ def sample_ddim_context(n_sample, context, n=20):
|
|
188 |
intermediate = np.stack(intermediate)
|
189 |
return samples, intermediate
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
def greet(input):
|
192 |
steps = int(input)
|
193 |
#samples, intermediate = sample_ddim(32, n=steps)
|
194 |
-
ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
195 |
-
samples, intermediate = sample_ddim_context(32, ctx, steps)
|
|
|
196 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
197 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
198 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|
|
|
188 |
intermediate = np.stack(intermediate)
|
189 |
return samples, intermediate
|
190 |
|
191 |
+
# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
|
192 |
+
def denoise_add_noise(x, t, pred_noise, z=None):
|
193 |
+
if z is None:
|
194 |
+
z = torch.randn_like(x)
|
195 |
+
noise = b_t.sqrt()[t] * z
|
196 |
+
mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
|
197 |
+
return mean + noise
|
198 |
+
|
199 |
+
# sample using standard algorithm
|
200 |
+
@torch.no_grad()
|
201 |
+
def sample_ddpm(n_sample, save_rate=20):
|
202 |
+
# x_T ~ N(0, 1), sample initial noise
|
203 |
+
samples = torch.randn(n_sample, 3, height, height).to(device)
|
204 |
+
|
205 |
+
# array to keep track of generated steps for plotting
|
206 |
+
intermediate = []
|
207 |
+
for i in range(timesteps, 0, -1):
|
208 |
+
print(f'sampling timestep {i:3d}', end='\r')
|
209 |
+
|
210 |
+
# reshape time tensor
|
211 |
+
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
|
212 |
+
|
213 |
+
# sample some random noise to inject back in. For i = 1, don't add back in noise
|
214 |
+
z = torch.randn_like(samples) if i > 1 else 0
|
215 |
+
|
216 |
+
eps = nn_model(samples, t) # predict noise e_(x_t,t)
|
217 |
+
samples = denoise_add_noise(samples, i, eps, z)
|
218 |
+
if i % save_rate ==0 or i==timesteps or i<8:
|
219 |
+
intermediate.append(samples.detach().cpu().numpy())
|
220 |
+
|
221 |
+
intermediate = np.stack(intermediate)
|
222 |
+
return samples, intermediate
|
223 |
+
|
224 |
def greet(input):
|
225 |
steps = int(input)
|
226 |
#samples, intermediate = sample_ddim(32, n=steps)
|
227 |
+
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
228 |
+
#samples, intermediate = sample_ddim_context(32, ctx, steps)
|
229 |
+
samples, intermediate = sample_ddpm(32, )
|
230 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
231 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
232 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|