import ast
from itertools import tee
from tqdm.auto import trange
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.ldm.models.diffusion.ddim import DDIMSampler
from comfy.samplers import CompVisVDenoiser
from comfy.utils import ProgressBar
from nodes import common_ksampler
import torch
from ..utils.refined_exp_solver import RefinedExpCallbackPayload, _refined_exp_sosu_step
from .restart_schedulers import SCHEDULER_MAPPING

# Sentinel recording that a module attribute did not exist before it was patched.
_MISSING = object()


def pairwise(iterable):
    """Return an iterator of overlapping pairs: s -> (s0, s1), (s1, s2), (s2, s3), ..."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def add_restart_segment(restart_segments, n_restart, k, t_min, t_max):
    """Append one ``{'n', 'k', 't_min', 't_max'}`` segment dict and return the list.

    A new list is created when *restart_segments* is None; otherwise the given
    list is appended to in place and returned.
    """
    if restart_segments is None:
        restart_segments = []
    restart_segments.append({'n': n_restart, 'k': k, 't_min': t_min, 't_max': t_max})
    return restart_segments


def prepare_restart_segments(restart_info):
    """Parse the user's restart specification into a list of segment dicts.

    *restart_info* is wrapped in ``[...]`` and parsed with ``ast.literal_eval``
    (literals only, no code execution). For example
    ``"[3, 2, 0.06, 0.30], [3, 1, 0.30, 0.60]"`` yields two segments. Each
    entry must be a 4-element list/tuple ``[n, k, t_min, t_max]``.

    Raises:
        SyntaxError, ValueError: the text is not a valid Python literal
            (re-raised after printing a short message).
        ValueError: an entry is not a list/tuple of exactly four values.
    """
    try:
        restart_arrays = ast.literal_eval(f"[{restart_info}]")
    except (SyntaxError, ValueError):
        # literal_eval raises ValueError (not SyntaxError) for well-formed but
        # non-literal input such as bare names, so report both the same way.
        print("Ill-formed restart segments")
        raise

    restart_segments = []
    for arr in restart_arrays:
        # A bare number (e.g. inner brackets omitted) has no len(), and a
        # 4-character string would unpack into characters; reject both clearly.
        if not isinstance(arr, (list, tuple)) or len(arr) != 4:
            raise ValueError("Restart segment must have 4 values")
        n_restart, k, t_min, t_max = arr
        restart_segments = add_restart_segment(restart_segments, n_restart, k, t_min, t_max)
    return restart_segments


def round_restart_segments(sigmas, restart_segments):
    """Snap each segment's ``t_min`` to the nearest value in *sigmas*.

    Returns a dict mapping the snapped sigma (a Python number from ``.item()``)
    to ``{'n', 'k', 't_max'}``. Segments whose ``t_max`` is greater than
    ``sigmas[0]`` are dropped. Segments are visited in reverse so that, when two
    snap to the same sigma, the one listed first overwrites the later one.
    """
    s_max = sigmas[0]
    t_min_mapping = {}
    for segment in reversed(restart_segments):
        if segment['t_max'] > s_max:
            continue  # toss the segment
        t_min_neighbor = min(sigmas, key=lambda s: abs(s - segment['t_min'])).item()
        t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
    return t_min_mapping


def segments_to_timesteps(restart_segments, model):
    """Convert each segment's sigma bounds to timesteps with ``model.sigma_to_t``."""
    timesteps = []
    for segment in restart_segments:
        t_min, t_max = model.sigma_to_t(torch.tensor(
            [segment['t_min'], segment['t_max']], device=model.log_sigmas.device))
        timesteps.append({'n': segment['n'], 'k': segment['k'], 't_min': t_min, 't_max': t_max})
    return timesteps


def round_restart_segments_timesteps(timesteps, restart_segments):
    """Timestep counterpart of :func:`round_restart_segments` (no ``t_max`` filtering).

    Maps the value in *timesteps* nearest to each segment's ``t_min`` (as a
    Python number) to ``{'n', 'k', 't_max'}``; the first-listed segment wins ties.
    """
    t_min_mapping = {}
    for segment in reversed(restart_segments):
        t_min_neighbor = min(timesteps, key=lambda ts: abs(ts - segment['t_min'])).item()
        t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
    return t_min_mapping


def calc_sigmas(scheduler, n, sigma_min, sigma_max, model, device):
    """Build an *n*-point restart sigma schedule with ``SCHEDULER_MAPPING[scheduler]``.

    The restart loops walk the result as ``seg_sigmas[j] -> seg_sigmas[j + 1]``,
    so it is presumably ordered from *sigma_max* down to *sigma_min* (confirm in
    restart_schedulers).
    """
    return SCHEDULER_MAPPING[scheduler](model, n, sigma_min, sigma_max, device)


def calc_restart_steps(restart_segments):
    """Count the extra solver steps added by the mapping from round_restart_segments*.

    Each segment runs ``k`` restart passes of ``n - 1`` steps.
    """
    return sum((segment['n'] - 1) * segment['k'] for segment in restart_segments.values())


# Module-level state shared between restart_sampling, the patched samplers and
# the ProgressBar wrapper (which reports _total_steps as the bar's total).
_total_steps = 0
_restart_segments = None
_restart_scheduler = None


def _maybe_clamp_cfg(extra_args, real_steps, cfg_clamp_after_step):
    """Set ``extra_args['cond_scale'] = 1.0`` once *real_steps* exceeds *cfg_clamp_after_step*.

    The dict is modified in place, so the change persists for the rest of the
    sampling call. A ``None`` *extra_args* is left alone instead of raising.
    """
    # NOTE(review): every wrapper defaults cfg_clamp_after_step to 0 and
    # restart_sampling never overrides it, so cond_scale is forced to 1.0 from
    # the second main step onward (presumably disabling CFG). Confirm intended.
    if extra_args is not None and real_steps > cfg_clamp_after_step:
        extra_args['cond_scale'] = 1.0


def restart_sampling(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise,
                     restart_info, restart_scheduler, begin_at_step, end_at_step, disable_noise, force_full_denoise):
    """Run ``common_ksampler`` with restart segments spliced into the chosen sampler.

    *restart_info* is parsed by :func:`prepare_restart_segments`;
    *restart_scheduler* selects the ``SCHEDULER_MAPPING`` entry used for restart
    sigmas. The sampler is patched for the duration of the call:
    "res" -> RESWrapper, "ddim" -> DDIMWrapper, anything else ->
    KSamplerRestartWrapper. *begin_at_step* / *end_at_step* are handed to the
    wrapper (only DDIMWrapper reads them). ``ProgressBar.update_absolute`` is
    temporarily wrapped so the reported total includes restart steps. Both
    patches are undone in ``finally``.

    Returns whatever ``common_ksampler`` returns.
    """
    global _total_steps, _restart_segments, _restart_scheduler
    _restart_scheduler = restart_scheduler
    _restart_segments = prepare_restart_segments(restart_info)

    if sampler_name == "res":
        sampler_wrapper = RESWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
    elif sampler_name == "ddim":
        sampler_wrapper = DDIMWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
    else:
        sampler_wrapper = KSamplerRestartWrapper(sampler_name, begin_at_step, end_at_step,
                                                 _restart_segments, _restart_scheduler)

    # Add the additional steps to the progress bar: ignore the caller's total and
    # report the restart-inclusive count the patched sampler stored globally.
    pbar_update_absolute = ProgressBar.update_absolute

    def pbar_update_absolute_wrapper(self, value, total=None, preview=None):
        pbar_update_absolute(self, value, _total_steps, preview)

    ProgressBar.update_absolute = pbar_update_absolute_wrapper
    try:
        samples = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative,
                                  latent_image, denoise=denoise, force_full_denoise=force_full_denoise,
                                  disable_noise=disable_noise)
    finally:
        sampler_wrapper.cleanup()
        ProgressBar.update_absolute = pbar_update_absolute
    return samples


class RestartWrapper:
    """Base for objects that patch a sampler while restart_sampling runs."""

    def cleanup(self):
        """Undo the patch installed by the constructor (no-op in the base class)."""
        pass


class RESWrapper(RestartWrapper):
    """Install a restart-aware RES sampler as ``k_diffusion_sampling.sample_res``.

    The installed function is a staticmethod, so its configuration is stored in
    class attributes rather than on the instance.
    """

    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        cls = self.__class__
        cls.segments = segments
        cls.scheduler = scheduler
        cls.refiner_stage = False
        cls.cfg_clamp_after_step = cfg_clamp_after_step
        self.sample_func_name = "sample_res"
        # Remember what was registered before (possibly nothing) so cleanup()
        # can undo the patch instead of leaving a stale wrapper installed.
        self.previous_sampler = getattr(k_diffusion_sampling, self.sample_func_name, _MISSING)
        setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)

    def cleanup(self):
        """Restore the previous ``sample_res``, or remove it if there was none."""
        if self.previous_sampler is not _MISSING:
            setattr(k_diffusion_sampling, self.sample_func_name, self.previous_sampler)
        elif hasattr(k_diffusion_sampling, self.sample_func_name):
            delattr(k_diffusion_sampling, self.sample_func_name)

    @staticmethod
    @torch.no_grad()
    def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Sample with the refined-exponential (RES) solver plus restart segments.

        After the main step leaving a sigma that matches a segment's snapped
        ``t_min``, x is re-noised from ``sigmas[i + 1]`` up to the segment's
        ``t_max`` and re-solved down the restart schedule, ``k`` times.
        Returns the final latent.
        """
        global _total_steps, _restart_segments, _restart_scheduler
        _restart_scheduler = RESWrapper.scheduler
        _restart_segments = RESWrapper.segments
        segments = round_restart_segments(sigmas, _restart_segments)
        _total_steps = len(sigmas) - 1 + calc_restart_steps(segments)

        ita = torch.zeros((1,), device=x.device)  # churn factor; 0 => sigma_hat == sigma, no noise added
        simple_phi_calc = True
        c2 = .5

        def res_step(x_cur, sigma_cur, sigma_to, pbar):
            """One (optionally churned) RES step from sigma_cur to sigma_to."""
            eps = torch.randn_like(x_cur, device=x_cur.device)
            sigma_hat = sigma_cur * (1 + ita)
            x_hat = x_cur + (sigma_hat ** 2 - sigma_cur ** 2) ** .5 * eps
            x_next, denoised, denoised2 = _refined_exp_sosu_step(
                model, x_hat, sigma_hat, sigma_to, c2=c2, extra_args=extra_args,
                pbar=pbar, simple_phi_calc=simple_phi_calc,
            )
            return x_next, sigma_hat, denoised, denoised2

        step = 0
        real_steps = 0
        with trange(_total_steps, disable=disable) as pbar:
            for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
                _maybe_clamp_cfg(extra_args, real_steps, RESWrapper.cfg_clamp_after_step)
                x_next, sigma_hat, denoised, denoised2 = res_step(x, sigma, sigma_next, pbar)
                if callback is not None:
                    callback(RefinedExpCallbackPayload(
                        x=x, i=step, sigma=sigma, sigma_hat=sigma_hat,
                        denoised=denoised, denoised2=denoised2,
                    ))
                x = x_next
                pbar.update(1)
                step += 1
                real_steps += 1

                if sigmas[i].item() in segments:
                    seg = segments[sigmas[i].item()]
                    s_min, s_max, k, n_restart = sigmas[i + 1], seg['t_max'], seg['k'], seg['n']
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min, s_max,
                                             model.inner_model, device=x.device)
                    for _ in range(k):
                        # Re-noise from s_min up to s_max, then solve back down the
                        # restart schedule (same scheme as KSamplerRestartWrapper).
                        # Previously the noise was disabled and x_next was discarded,
                        # so restart segments had no effect.
                        x = x + torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            x = res_step(x, seg_sigmas[j], seg_sigmas[j + 1], pbar)[0]
                            pbar.update(1)
                            step += 1

            # NOTE(review): the main loop only walks sigmas[:-1], so unless
            # refiner_stage is set (this file always sets it False) the step to
            # sigmas[-1] is never taken and x is returned at sigmas[-2]. The block
            # below reuses the last loop `sigma` (and fails if the loop never ran).
            if RESWrapper.refiner_stage:
                eps = torch.randn_like(x, device=x.device)
                sigma_hat = sigma * (1 + ita)
                x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps
                x = model(x_hat, sigma.to(x_hat.device), **extra_args)
                pbar.update()
        return x


class KSamplerRestartWrapper(RestartWrapper):
    """Wrap a stock ``k_diffusion_sampling.sample_<name>`` with restart segments.

    The original function is kept in ``ksampler`` and driven one sigma interval
    at a time. Configuration lives in class attributes because the installed
    function is a staticmethod.
    """

    ksampler = None

    def __init__(self, sampler_name, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        cls = self.__class__
        self.sample_func_name = "sample_{}".format(sampler_name)
        cls.segments = segments
        cls.ksampler = getattr(k_diffusion_sampling, self.sample_func_name)
        cls.begin_at_step = begin_at_step
        cls.end_at_step = end_at_step
        cls.scheduler = scheduler
        cls.refiner_stage = False
        cls.original_sigmas = None
        cls.total_steps = 0
        cls.continuation_step = 0
        cls.cfg_clamp_after_step = cfg_clamp_after_step
        setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)

    def cleanup(self):
        """Reinstall the original sampler function."""
        setattr(k_diffusion_sampling, self.sample_func_name, KSamplerRestartWrapper.ksampler)

    @staticmethod
    @torch.no_grad()
    def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Run the wrapped sampler interval by interval, inserting restart passes.

        After the step leaving a sigma that matches a segment's snapped
        ``t_min``, x is re-noised (variance ``t_max**2 - sigmas[i+1]**2``) and
        re-sampled down the restart schedule, ``k`` times. Returns the final latent.
        """
        global _total_steps, _restart_segments, _restart_scheduler
        cls = KSamplerRestartWrapper
        ksampler = cls.ksampler
        _restart_scheduler = cls.scheduler
        _restart_segments = cls.segments
        segments = round_restart_segments(sigmas, _restart_segments)
        _total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
        step = 0
        if not cls.refiner_stage:
            cls.original_sigmas = sigmas
            cls.total_steps = _total_steps

        def callback_wrapper(x):
            # Report the restart-inclusive step counter, not the per-call index.
            x["i"] = step
            if callback is not None:
                callback(x)

        real_steps = 0
        with trange(_total_steps, disable=disable) as pbar:
            for i in range(len(sigmas) - 1):
                _maybe_clamp_cfg(extra_args, real_steps, cls.cfg_clamp_after_step)
                x = ksampler(model, x, torch.tensor([sigmas[i], sigmas[i + 1]], device=x.device),
                             extra_args, callback_wrapper, True)
                pbar.update(1)
                step += 1
                real_steps += 1

                if sigmas[i].item() in segments:
                    seg = segments[sigmas[i].item()]
                    s_min, s_max, k, n_restart = sigmas[i + 1], seg['t_max'], seg['k'], seg['n']
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min, s_max,
                                             model.inner_model, device=x.device)
                    for _ in range(k):
                        x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            x = ksampler(model, x, torch.tensor([seg_sigmas[j], seg_sigmas[j + 1]], device=x.device),
                                         extra_args, callback_wrapper, True)
                            pbar.update(1)
                            step += 1
        return x


class DDIMWrapper(RestartWrapper):
    """Replace ``DDIMSampler.sample_custom`` with a restart-aware version.

    The original method is kept in ``sample_custom`` and invoked one timestep
    interval at a time.
    """

    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        cls = self.__class__
        cls.sample_custom = DDIMSampler.sample_custom
        cls.begin_at_step = begin_at_step
        cls.end_at_step = end_at_step
        cls.segments = segments
        cls.scheduler = scheduler
        cls.refiner_stage = False
        cls.cfg_clamp_after_step = cfg_clamp_after_step
        # Accessed via the instance, the staticmethod yields the plain function,
        # which then binds as a normal method of DDIMSampler.
        DDIMSampler.sample_custom = self.ddim_wrapper

    def cleanup(self):
        """Reinstall the original ``DDIMSampler.sample_custom``."""
        DDIMSampler.sample_custom = self.__class__.sample_custom

    @staticmethod
    @torch.no_grad()
    def ddim_wrapper(self, ddim_timesteps, conditioning, callback=None, img_callback=None, quantize_x0=False,
                     eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None,
                     corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100,
                     unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                     ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None,
                     disable_pbar=False, **kwargs):
        """Restart-aware ``sample_custom``; ``self`` is the DDIMSampler instance.

        Runs the original method over two-timestep slices of *ddim_timesteps*
        and inserts restart passes after matching timesteps. *to_zero* and
        *end_step* are recomputed per slice; extra ``kwargs`` are ignored.
        Returns ``(x0, intermediates)`` from the last slice.
        """
        global _total_steps, _restart_segments, _restart_scheduler
        cls = DDIMWrapper
        ddim_sampler = cls.sample_custom
        model_denoise = CompVisVDenoiser(self.model)
        begin_at_step = cls.begin_at_step
        end_at_step = cls.end_at_step
        _restart_segments = cls.segments
        _restart_scheduler = cls.scheduler
        segments = segments_to_timesteps(_restart_segments, model_denoise)
        segments = round_restart_segments_timesteps(ddim_timesteps, segments)
        _total_steps = len(ddim_timesteps) - 1 + calc_restart_steps(segments)
        step = 0

        def callback_wrapper(pred_x0, i):
            # Report the restart-inclusive step counter instead of the per-slice i.
            if img_callback is not None:
                img_callback(pred_x0, step)

        def ddim_simplified(x, timesteps, x_T=None, disable_pbar=False):
            # Without an explicit x_T, deterministically re-encode x at the top of
            # this slice (zero noise) so the original sampler can continue from it.
            if x_T is None:
                self.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
                x_T = self.stochastic_encode(x, torch.tensor([len(timesteps) - 1] * x.shape[0]).to(self.device),
                                             noise=torch.zeros_like(x), max_denoise=False)
            x, intermediates = ddim_sampler(
                self, timesteps, conditioning, callback=callback, img_callback=callback_wrapper,
                quantize_x0=quantize_x0, eta=eta, mask=mask, x0=x, temperature=temperature,
                noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs,
                verbose=verbose, x_T=x_T, log_every_t=log_every_t,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning, dynamic_threshold=dynamic_threshold,
                ucg_schedule=ucg_schedule, denoise_function=denoise_function, extra_args=extra_args,
                to_zero=timesteps[0].item() == 0, end_step=len(timesteps) - 1, disable_pbar=disable_pbar,
            )
            return x, intermediates

        intermediates = None
        real_steps = 0
        with trange(_total_steps, disable=disable_pbar) as pbar:
            # Walk indices from high to low, limited by begin_at_step/end_at_step.
            # NOTE(review): when begin_at_step <= 0 the first i is
            # len(ddim_timesteps) - 1, so ddim_timesteps[i:i + 2] holds a single
            # timestep on that call — confirm this is intended.
            rg = reversed(range(
                len(ddim_timesteps) - 1 - min(len(ddim_timesteps) - 1, end_at_step + 1),  # start
                len(ddim_timesteps) - max(begin_at_step, 0),  # end
            ))
            for i in rg:
                _maybe_clamp_cfg(extra_args, real_steps, cls.cfg_clamp_after_step)
                x0, intermediates = ddim_simplified(x0, ddim_timesteps[i:i + 2], x_T=x_T, disable_pbar=True)
                x_T = None  # only the first slice starts from the caller's x_T
                pbar.update(1)
                step += 1
                real_steps += 1

                if ddim_timesteps[i].item() in segments:
                    seg = segments[ddim_timesteps[i].item()]
                    t_min, t_max, k, n_restart = ddim_timesteps[i], seg['t_max'], seg['k'], seg['n']
                    s_min, s_max = model_denoise.t_to_sigma(t_min), model_denoise.t_to_sigma(t_max)
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min, s_max,
                                             model_denoise, device=x0.device)
                    for _ in range(k):
                        x0 += torch.randn_like(x0) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            seg_ts = model_denoise.sigma_to_t(seg_sigmas[j]).to(torch.int32)
                            seg_ts_next = model_denoise.sigma_to_t(seg_sigmas[j + 1]).to(torch.int32)
                            x0, intermediates = ddim_simplified(x0, [seg_ts_next, seg_ts], disable_pbar=True)
                            pbar.update(1)
                            step += 1
        return x0, intermediates