debisoft committed
Commit e18412d · 1 Parent(s): 9eca6d8
Files changed (1)
  1. app.py +30 -1
app.py CHANGED
@@ -161,9 +161,38 @@ def sample_ddim(n_sample, n=20):
     intermediate = np.stack(intermediate)
     return samples, intermediate
 
+# load in model weights and set to eval mode
+nn_model.load_state_dict(torch.load(f"{save_dir}/context_model_31.pth", map_location=device))
+nn_model.eval()
+print("Loaded in Context Model")
+
+# fast sampling algorithm with context
+@torch.no_grad()
+def sample_ddim_context(n_sample, context, n=20):
+    # x_T ~ N(0, 1), sample initial noise
+    samples = torch.randn(n_sample, 3, height, height).to(device)
+
+    # array to keep track of generated steps for plotting
+    intermediate = []
+    step_size = timesteps // n
+    for i in range(timesteps, 0, -step_size):
+        print(f'sampling timestep {i:3d}', end='\r')
+
+        # reshape time tensor
+        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
+
+        eps = nn_model(samples, t, c=context)    # predict noise e_(x_t,t)
+        samples = denoise_ddim(samples, i, i - step_size, eps)
+        intermediate.append(samples.detach().cpu().numpy())
+
+    intermediate = np.stack(intermediate)
+    return samples, intermediate
+
 def greet(input):
     steps = int(input)
-    samples, intermediate = sample_ddim(32, n=steps)
+    #samples, intermediate = sample_ddim(32, n=steps)
+    ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
+    samples, intermediate = sample_ddim_context(32, ctx)
     #response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
     #response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
     #response = im.fromarray(intermediate[24][0][1]).convert("RGB")
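
Note: the new sample_ddim_context relies on a denoise_ddim helper defined elsewhere in app.py and not shown in this diff. As a point of reference, here is a minimal sketch of the standard deterministic DDIM update (eta = 0) that such a helper typically implements; the name ab_t for the precomputed cumulative noise schedule (alpha-bar over timesteps) is an assumption, not something this commit shows:

# Sketch only: assumes a global tensor ab_t holding the cumulative product
# of (1 - beta) per timestep, i.e. alpha-bar, precomputed elsewhere in app.py.
def denoise_ddim(x, t, t_prev, pred_noise):
    ab = ab_t[t]
    ab_prev = ab_t[t_prev]

    # estimate of the clean image x_0, rescaled to timestep t_prev
    x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
    # deterministic direction pointing back toward x_t (no noise term, eta = 0)
    dir_xt = (1 - ab_prev).sqrt() * pred_noise

    return x0_pred + dir_xt

Because this update is deterministic, jumping from timestep i straight to i - step_size needs no freshly injected noise, which is what lets sample_ddim_context run in n = 20 steps rather than the full timesteps.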
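
The updated greet draws a random class label for every sample via F.one_hot(torch.randint(0, 5, (32,)), 5). To condition the whole batch on a single chosen class instead, the same function can be called with a fixed one-hot batch; the class index 3 and batch size 8 below are arbitrary, hypothetical values:

# hypothetical usage: every sample conditioned on class 3 out of 5
ctx = F.one_hot(torch.full((8,), 3), num_classes=5).to(device=device).float()
samples, intermediate = sample_ddim_context(8, ctx)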