Tony Lian commited on
Commit
b1ea54b
1 Parent(s): 363907c

Use fast_after_steps only with use_fast_schedule

Browse files
Files changed (1) hide show
  1. utils/latents.py +4 -1
utils/latents.py CHANGED
@@ -74,7 +74,10 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc
74
  latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
75
  foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
76
  mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
77
- composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all[:fast_after_steps + 1] * mask_tensor_expanded
 
 
 
78
 
79
  composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
80
  return composed_latents, foreground_indices
 
74
  latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
75
  foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
76
  mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
77
+ if use_fast_schedule:
78
+ composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all[:fast_after_steps + 1] * mask_tensor_expanded
79
+ else:
80
+ composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all * mask_tensor_expanded
81
 
82
  composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
83
  return composed_latents, foreground_indices