Spaces:
Sleeping
Sleeping
context
Browse files
app.py
CHANGED
@@ -161,9 +161,38 @@ def sample_ddim(n_sample, n=20):
|
|
161 |
intermediate = np.stack(intermediate)
|
162 |
return samples, intermediate
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
def greet(input):
|
165 |
steps = int(input)
|
166 |
-
samples, intermediate = sample_ddim(32, n=steps)
|
|
|
|
|
167 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
168 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
169 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|
|
|
161 |
intermediate = np.stack(intermediate)
|
162 |
return samples, intermediate
|
163 |
|
164 |
+
# load in model weights and set to eval mode
|
165 |
+
nn_model.load_state_dict(torch.load(f"{save_dir}/context_model_31.pth", map_location=device))
|
166 |
+
nn_model.eval()
|
167 |
+
print("Loaded in Context Model")
|
168 |
+
|
169 |
+
# fast sampling algorithm with context
|
170 |
+
@torch.no_grad()
|
171 |
+
def sample_ddim_context(n_sample, context, n=20):
|
172 |
+
# x_T ~ N(0, 1), sample initial noise
|
173 |
+
samples = torch.randn(n_sample, 3, height, height).to(device)
|
174 |
+
|
175 |
+
# array to keep track of generated steps for plotting
|
176 |
+
intermediate = []
|
177 |
+
step_size = timesteps // n
|
178 |
+
for i in range(timesteps, 0, -step_size):
|
179 |
+
print(f'sampling timestep {i:3d}', end='\r')
|
180 |
+
|
181 |
+
# reshape time tensor
|
182 |
+
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
|
183 |
+
|
184 |
+
eps = nn_model(samples, t, c=context) # predict noise e_(x_t,t)
|
185 |
+
samples = denoise_ddim(samples, i, i - step_size, eps)
|
186 |
+
intermediate.append(samples.detach().cpu().numpy())
|
187 |
+
|
188 |
+
intermediate = np.stack(intermediate)
|
189 |
+
return samples, intermediate
|
190 |
+
|
191 |
def greet(input):
|
192 |
steps = int(input)
|
193 |
+
#samples, intermediate = sample_ddim(32, n=steps)
|
194 |
+
ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
195 |
+
samples, intermediate = sample_ddim_context(32, ctx)
|
196 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
197 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
198 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|