Spaces:
Running
Running
sample_ddpm_context
Browse files
app.py
CHANGED
@@ -198,7 +198,7 @@ def denoise_add_noise(x, t, pred_noise, z=None):
|
|
198 |
|
199 |
# sample using standard algorithm
|
200 |
@torch.no_grad()
|
201 |
-
def sample_ddpm(n_sample, save_rate=20):
|
202 |
# x_T ~ N(0, 1), sample initial noise
|
203 |
samples = torch.randn(n_sample, 3, height, height).to(device)
|
204 |
|
@@ -221,6 +221,30 @@ def sample_ddpm(n_sample, save_rate=20):
|
|
221 |
intermediate = np.stack(intermediate)
|
222 |
return samples, intermediate
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
def greet(input):
|
225 |
steps = int(input)
|
226 |
|
@@ -233,7 +257,7 @@ def greet(input):
|
|
233 |
ctx = torch.from_numpy(mtx_2d).to(device=device).float()
|
234 |
|
235 |
#samples, intermediate = sample_ddim_context(32, ctx, n=steps)
|
236 |
-
samples, intermediate =
|
237 |
|
238 |
#samples, intermediate = sample_ddim(32, n=steps)
|
239 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
|
|
198 |
|
199 |
# sample using standard algorithm
|
200 |
@torch.no_grad()
|
201 |
+
def sample_ddpm(n_sample, context, save_rate=20):
|
202 |
# x_T ~ N(0, 1), sample initial noise
|
203 |
samples = torch.randn(n_sample, 3, height, height).to(device)
|
204 |
|
|
|
221 |
intermediate = np.stack(intermediate)
|
222 |
return samples, intermediate
|
223 |
|
224 |
+
@torch.no_grad()
def sample_ddpm_context(n_sample, context, save_rate=20):
    """Sample a batch of images with the standard DDPM reverse process,
    conditioning the noise-prediction model on `context`.

    Fix: the original signature omitted `context` even though the body
    reads it (``nn_model(samples, t, c=context)``) and the caller passes
    it positionally (``sample_ddpm_context(32, ctx, steps)``) — which
    would raise TypeError/NameError. It is now an explicit parameter,
    matching the sibling ``sample_ddpm(n_sample, context, save_rate=20)``.

    Parameters
    ----------
    n_sample : int
        Number of images to generate in one batch.
    context : torch.Tensor
        Conditioning tensor forwarded to ``nn_model`` via its ``c`` kwarg.
        NOTE(review): presumably shaped (n_sample, context_dim) — confirm
        against nn_model's definition, which is outside this view.
    save_rate : int, optional
        Interval (in timesteps) at which intermediate samples are saved.
        The caller passes its ``steps`` value here positionally.

    Returns
    -------
    samples : torch.Tensor
        Final denoised samples, shape (n_sample, 3, height, height).
    intermediate : numpy.ndarray
        Stacked snapshots of the sampling trajectory (for plotting).
    """
    # x_T ~ N(0, 1): start the reverse process from pure Gaussian noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # snapshots of intermediate steps, for plotting the trajectory
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time scalar so it broadcasts over (batch, C, H, W)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # noise to inject back in; none on the final step (i == 1)
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t, c=context)    # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)
        # save periodically, plus the first step and the last few
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
|
247 |
+
|
248 |
def greet(input):
|
249 |
steps = int(input)
|
250 |
|
|
|
257 |
ctx = torch.from_numpy(mtx_2d).to(device=device).float()
|
258 |
|
259 |
#samples, intermediate = sample_ddim_context(32, ctx, n=steps)
|
260 |
+
samples, intermediate = sample_ddpm_context(32, ctx, steps)
|
261 |
|
262 |
#samples, intermediate = sample_ddim(32, n=steps)
|
263 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|