Commit
·
6ded867
1
Parent(s):
541c8d3
use tqdm to track the current step
Browse files
cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py
CHANGED
@@ -27,6 +27,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|
27 |
|
28 |
import attrs
|
29 |
import torch
|
|
|
30 |
|
31 |
from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
|
32 |
from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
|
@@ -204,7 +205,7 @@ def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_
|
|
204 |
The final result after all iterations.
|
205 |
"""
|
206 |
val = init_val
|
207 |
-
for i in range(lower, upper):
|
208 |
val = body_fun(i, val)
|
209 |
return val
|
210 |
|
@@ -251,7 +252,7 @@ def differential_equation_solver(
|
|
251 |
def step_fn(
|
252 |
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
253 |
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
254 |
-
log.
|
255 |
input_x_B_StateShape, x0_preds = state
|
256 |
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
257 |
|
|
|
27 |
|
28 |
import attrs
|
29 |
import torch
|
30 |
+
from tqdm import tqdm
|
31 |
|
32 |
from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
|
33 |
from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
|
|
|
205 |
The final result after all iterations.
|
206 |
"""
|
207 |
val = init_val
|
208 |
+
for i in tqdm(range(lower, upper)):
|
209 |
val = body_fun(i, val)
|
210 |
return val
|
211 |
|
|
|
252 |
def step_fn(
|
253 |
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
254 |
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
255 |
+
log.debug(f"Step [{i_th}/{num_step}]")
|
256 |
input_x_B_StateShape, x0_preds = state
|
257 |
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
258 |
|