Spaces:
Runtime error
Runtime error
n_sample memory collect
Browse files
gradio_utils/motionctrl_cmcm_gradio.py
CHANGED
@@ -184,10 +184,11 @@ def motionctrl_sample(
|
|
184 |
model.en_and_decode_n_samples_a_time = decoding_t
|
185 |
samples_x = model.decode_first_stage(samples_z)
|
186 |
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) # [1*t, c, h, w]
|
|
|
187 |
results.append(samples)
|
188 |
|
189 |
samples = torch.stack(results, dim=0) # [sample_num, t, c, h, w]
|
190 |
-
samples = samples.data.cpu()
|
191 |
|
192 |
video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
193 |
save_results(samples, video_path, fps=save_fps)
|
|
|
184 |
model.en_and_decode_n_samples_a_time = decoding_t
|
185 |
samples_x = model.decode_first_stage(samples_z)
|
186 |
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) # [1*t, c, h, w]
|
187 |
+
samples = samples.data.cpu()
|
188 |
results.append(samples)
|
189 |
|
190 |
samples = torch.stack(results, dim=0) # [sample_num, t, c, h, w]
|
191 |
+
# samples = samples.data.cpu()
|
192 |
|
193 |
video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
194 |
save_results(samples, video_path, fps=save_fps)
|