harry900000 commited on
Commit
6ded867
·
1 Parent(s): 541c8d3

use tqdm to track the current step

Browse files
cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py CHANGED
@@ -27,6 +27,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
27
 
28
  import attrs
29
  import torch
 
30
 
31
  from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
32
  from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
@@ -204,7 +205,7 @@ def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_
204
  The final result after all iterations.
205
  """
206
  val = init_val
207
- for i in range(lower, upper):
208
  val = body_fun(i, val)
209
  return val
210
 
@@ -251,7 +252,7 @@ def differential_equation_solver(
251
  def step_fn(
252
  i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
253
  ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
254
- log.info(f"Step [{i_th}/{num_step}]")
255
  input_x_B_StateShape, x0_preds = state
256
  sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
257
 
 
27
 
28
  import attrs
29
  import torch
30
+ from tqdm import tqdm
31
 
32
  from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
33
  from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
 
205
  The final result after all iterations.
206
  """
207
  val = init_val
208
+ for i in tqdm(range(lower, upper)):
209
  val = body_fun(i, val)
210
  return val
211
 
 
252
  def step_fn(
253
  i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
254
  ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
255
+ log.debug(f"Step [{i_th}/{num_step}]")
256
  input_x_B_StateShape, x0_preds = state
257
  sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
258