"""Restart sampling wrappers that patch ComfyUI samplers with restart segments."""
import ast
from itertools import tee

import torch
from tqdm.auto import trange

from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.ldm.models.diffusion.ddim import DDIMSampler
from comfy.samplers import CompVisVDenoiser
from comfy.utils import ProgressBar
from nodes import common_ksampler

from ..utils.refined_exp_solver import RefinedExpCallbackPayload, _refined_exp_sosu_step
from .restart_schedulers import SCHEDULER_MAPPING


def pairwise(iterable):
    """s -> (s0, s1), (s1, s2), (s2, s3), ..."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def add_restart_segment(restart_segments, n_restart, k, t_min, t_max):
    if restart_segments is None:
        restart_segments = []
    restart_segments.append({'n': n_restart, 'k': k, 't_min': t_min, 't_max': t_max})
    return restart_segments


def prepare_restart_segments(restart_info):
    try:
        # literal_eval raises ValueError (not just SyntaxError) on malformed literals.
        restart_arrays = ast.literal_eval(f"[{restart_info}]")
    except (SyntaxError, ValueError) as e:
        print("Ill-formed restart segments")
        raise e
    restart_segments = []
    for arr in restart_arrays:
        if len(arr) != 4:
            raise ValueError("Restart segment must have 4 values")
        n_restart, k, t_min, t_max = arr
        restart_segments = add_restart_segment(restart_segments, n_restart, k, t_min, t_max)
    return restart_segments
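
# A minimal usage sketch of the parser above (the numeric values are
# illustrative, not defaults shipped with this repo):
#
#   prepare_restart_segments("[3, 2, 0.06, 0.30], [2, 1, 0.30, 0.59]")
#   -> [{'n': 3, 'k': 2, 't_min': 0.06, 't_max': 0.30},
#       {'n': 2, 'k': 1, 't_min': 0.30, 't_max': 0.59}]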


def round_restart_segments(sigmas, restart_segments):
    """Map each segment's t_min to the closest sigma in the schedule."""
    s_max = sigmas[0]
    t_min_mapping = {}
    for segment in reversed(restart_segments):
        if segment['t_max'] > s_max:
            # The segment starts above the schedule's highest sigma; skip it.
            continue
        t_min_neighbor = min(sigmas, key=lambda s: abs(s - segment['t_min'])).item()
        t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
    return t_min_mapping
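
# For example, given sigmas = torch.tensor([10.0, 5.0, 1.0, 0.1, 0.0]) and one
# segment {'n': 3, 'k': 1, 't_min': 0.08, 't_max': 2.0} (illustrative values),
# t_min snaps to the nearest schedule sigma, 0.1, giving
# {0.1: {'n': 3, 'k': 1, 't_max': 2.0}}.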


def segments_to_timesteps(restart_segments, model):
    timesteps = []
    for segment in restart_segments:
        t_min, t_max = model.sigma_to_t(torch.tensor(
            [segment['t_min'], segment['t_max']], device=model.log_sigmas.device))
        ts_segment = {'n': segment['n'], 'k': segment['k'], 't_min': t_min, 't_max': t_max}
        timesteps.append(ts_segment)
    return timesteps


def round_restart_segments_timesteps(timesteps, restart_segments):
    t_min_mapping = {}
    for segment in reversed(restart_segments):
        t_min_neighbor = min(timesteps, key=lambda ts: abs(ts - segment['t_min'])).item()
        t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
    return t_min_mapping


def calc_sigmas(scheduler, n, sigma_min, sigma_max, model, device):
    return SCHEDULER_MAPPING[scheduler](model, n, sigma_min, sigma_max, device)
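
# Inferred contract, going by the call sites below (an assumption, not spelled
# out in this file): SCHEDULER_MAPPING[scheduler] is a callable
# (model, n, sigma_min, sigma_max, device) -> tensor of at least n sigmas
# descending from sigma_max toward sigma_min, consumed as seg_sigmas[0..n-1].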


def calc_restart_steps(restart_segments):
    restart_steps = 0
    for segment in restart_segments.values():
        restart_steps += (segment['n'] - 1) * segment['k']
    return restart_steps
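
# Each segment contributes (n - 1) * k extra sampler steps: n sigmas give
# n - 1 steps per pass, repeated k times. E.g. a segment with n=3, k=2 adds
# (3 - 1) * 2 = 4 steps to the progress-bar total (illustrative values).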


# Module-level state shared between restart_sampling and the sampler wrappers.
_total_steps = 0
_restart_segments = None
_restart_scheduler = None


def restart_sampling(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                     denoise, restart_info, restart_scheduler, begin_at_step, end_at_step,
                     disable_noise, force_full_denoise):
    global _total_steps, _restart_segments, _restart_scheduler
    _restart_scheduler = restart_scheduler
    _restart_segments = prepare_restart_segments(restart_info)

    # Pick the wrapper that monkey-patches the requested sampler.
    if sampler_name == "res":
        sampler_wrapper = RESWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
    elif sampler_name == "ddim":
        sampler_wrapper = DDIMWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
    else:
        sampler_wrapper = KSamplerRestartWrapper(sampler_name, begin_at_step, end_at_step,
                                                 _restart_segments, _restart_scheduler)

    # Patch the progress bar so it reports against the true total, which
    # includes the extra restart steps.
    pbar_update_absolute = ProgressBar.update_absolute

    def pbar_update_absolute_wrapper(self, value, total=None, preview=None):
        pbar_update_absolute(self, value, _total_steps, preview)

    ProgressBar.update_absolute = pbar_update_absolute_wrapper

    try:
        samples = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler,
                                  positive, negative, latent_image, denoise=denoise,
                                  force_full_denoise=force_full_denoise, disable_noise=disable_noise)
    finally:
        sampler_wrapper.cleanup()
        ProgressBar.update_absolute = pbar_update_absolute
    return samples
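
# The patch/restore pattern used above, in miniature (a generic sketch, not
# code from this repo):
#
#   original = SomeClass.method
#   SomeClass.method = patched_method
#   try:
#       ...  # everything sampled here sees the patched method
#   finally:
#       SomeClass.method = original  # always restored, even on error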


class RestartWrapper:
    """Base class for the sampler wrappers; cleanup() restores patched functions."""

    def cleanup(self):
        pass


class RESWrapper(RestartWrapper):
    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        self.__class__.segments = segments
        self.__class__.scheduler = scheduler
        self.__class__.refiner_stage = False
        self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
        # "sample_res" is injected fresh here; the inherited no-op cleanup() leaves it in place.
        self.sample_func_name = "sample_res"
        setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)

    @staticmethod
    @torch.no_grad()
    def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
        global _total_steps, _restart_segments, _restart_scheduler
        _restart_scheduler = __class__.scheduler
        _restart_segments = __class__.segments
        segments = round_restart_segments(sigmas, _restart_segments)
        _total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
        step = 0
        real_steps = 0
        # eta = 0 keeps sigma_hat == sigma, so no extra churn noise is added per step.
        eta: torch.FloatTensor = torch.zeros((1,), device=x.device)
        simple_phi_calc = True
        c2 = .5
        with trange(_total_steps, disable=disable) as pbar:
            for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
                if real_steps > __class__.cfg_clamp_after_step:
                    extra_args['cond_scale'] = 1.0

                eps: torch.FloatTensor = torch.randn_like(x, device=x.device)
                sigma_hat = sigma * (1 + eta)
                x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps
                x_next, denoised, denoised2 = _refined_exp_sosu_step(
                    model,
                    x_hat,
                    sigma_hat,
                    sigma_next,
                    c2=c2,
                    extra_args=extra_args,
                    pbar=pbar,
                    simple_phi_calc=simple_phi_calc,
                )
                if callback is not None:
                    payload = RefinedExpCallbackPayload(
                        x=x,
                        i=step,
                        sigma=sigma,
                        sigma_hat=sigma_hat,
                        denoised=denoised,
                        denoised2=denoised2,
                    )
                    callback(payload)
                x = x_next
                pbar.update(1)
                step += 1
                real_steps += 1
                if sigmas[i].item() in segments:
                    seg = segments[sigmas[i].item()]
                    s_min, s_max, k, n_restart = sigmas[i + 1], seg['t_max'], seg['k'], seg['n']
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
                                             s_max, model.inner_model, device=x.device)
                    for _ in range(k):
                        # Re-noise from s_min back up to s_max, then denoise the
                        # restart segment again (mirrors KSamplerRestartWrapper below).
                        x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            eps = torch.randn_like(x, device=x.device)
                            sigma_hat = seg_sigmas[j] * (1 + eta)
                            x_hat = x + (sigma_hat ** 2 - seg_sigmas[j] ** 2) ** .5 * eps
                            x_next, denoised, denoised2 = _refined_exp_sosu_step(
                                model,
                                x_hat,
                                sigma_hat,
                                seg_sigmas[j + 1],
                                c2=c2,
                                extra_args=extra_args,
                                pbar=pbar,
                                simple_phi_calc=simple_phi_calc,
                            )
                            x = x_next
                            pbar.update(1)
                            step += 1
            if __class__.refiner_stage:
                # Refiner stage: finish with one direct denoise at the current sigma.
                eps = torch.randn_like(x, device=x.device)
                sigma_hat = sigma * (1 + eta)
                x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps
                x_next = model(x_hat, sigma.to(x_hat.device), **extra_args)
                pbar.update(1)
                x = x_next
        return x
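
# Why restart noise is scaled by (s_max ** 2 - s_min ** 2) ** 0.5: a latent at
# noise level s_min carries noise variance s_min**2; adding independent Gaussian
# noise of variance s_max**2 - s_min**2 gives total variance
#   s_min**2 + (s_max**2 - s_min**2) = s_max**2,
# i.e. a valid sample at noise level s_max, where each restart pass begins.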


class KSamplerRestartWrapper(RestartWrapper):

    ksampler = None

    def __init__(self, sampler_name, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        self.sample_func_name = "sample_{}".format(sampler_name)
        self.__class__.segments = segments
        # Keep the original sampler so cleanup() can restore it.
        self.__class__.ksampler = getattr(k_diffusion_sampling, self.sample_func_name)
        self.__class__.begin_at_step = begin_at_step
        self.__class__.end_at_step = end_at_step
        self.__class__.scheduler = scheduler
        self.__class__.refiner_stage = False
        self.__class__.original_sigmas = None
        self.__class__.total_steps = 0
        self.__class__.continuation_step = 0
        self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
        setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)

    def cleanup(self):
        setattr(k_diffusion_sampling, self.sample_func_name, KSamplerRestartWrapper.ksampler)

    @staticmethod
    @torch.no_grad()
    def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
        global _total_steps, _restart_segments, _restart_scheduler
        ksampler = __class__.ksampler
        _restart_scheduler = __class__.scheduler
        _restart_segments = __class__.segments
        begin_at_step = __class__.begin_at_step
        end_at_step = __class__.end_at_step
        segments = round_restart_segments(sigmas, _restart_segments)
        _total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
        step = 0
        if not __class__.refiner_stage:
            __class__.original_sigmas = sigmas
            __class__.total_steps = _total_steps

        def callback_wrapper(x):
            x["i"] = step
            if callback is not None:
                callback(x)

        real_steps = 0
        with trange(_total_steps, disable=disable) as pbar:
            for i in range(len(sigmas) - 1):
                if real_steps > __class__.cfg_clamp_after_step:
                    extra_args['cond_scale'] = 1.0

                # Run the wrapped sampler for a single step of the schedule.
                x = ksampler(model, x, torch.tensor([sigmas[i], sigmas[i + 1]],
                             device=x.device), extra_args, callback_wrapper, True)
                pbar.update(1)
                step += 1
                real_steps += 1
                if sigmas[i].item() in segments:
                    seg = segments[sigmas[i].item()]
                    s_min, s_max, k, n_restart = sigmas[i + 1], seg['t_max'], seg['k'], seg['n']
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
                                             s_max, model.inner_model, device=x.device)
                    for _ in range(k):
                        # Jump back up to s_max with fresh noise, then re-denoise
                        # the restart segment.
                        x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            x = ksampler(model, x, torch.tensor(
                                [seg_sigmas[j], seg_sigmas[j + 1]], device=x.device), extra_args, callback_wrapper, True)
                            pbar.update(1)
                            step += 1
        return x
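
# The wrapper above drives the patched sampler one schedule step at a time: a
# schedule [s0, s1, ..., sN] becomes N single-step calls on the pairs [s0, s1],
# [s1, s2], ..., so restart noise can be injected between steps. The idea in a
# generic sketch (`sample` and `maybe_restart` are hypothetical names):
#
#   for a, b in pairwise(sigmas):
#       x = sample(model, x, torch.tensor([a, b], device=x.device))
#       x = maybe_restart(x)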


class DDIMWrapper(RestartWrapper):

    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        # Keep the original sample_custom so cleanup() can restore it.
        self.__class__.sample_custom = DDIMSampler.sample_custom
        self.__class__.begin_at_step = begin_at_step
        self.__class__.end_at_step = end_at_step
        self.__class__.segments = segments
        self.__class__.scheduler = scheduler
        self.__class__.refiner_stage = False
        self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
        DDIMSampler.sample_custom = self.ddim_wrapper

    def cleanup(self):
        DDIMSampler.sample_custom = self.__class__.sample_custom

    @staticmethod
    @torch.no_grad()
    def ddim_wrapper(self, ddim_timesteps, conditioning, callback=None, img_callback=None, quantize_x0=False,
                     eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None,
                     corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.,
                     unconditional_conditioning=None, dynamic_threshold=None, ucg_schedule=None, denoise_function=None,
                     extra_args=None, to_zero=True, end_step=None, disable_pbar=False, **kwargs):
        global _total_steps, _restart_segments, _restart_scheduler
        ddim_sampler = __class__.sample_custom
        model_denoise = CompVisVDenoiser(self.model)

        begin_at_step = __class__.begin_at_step
        end_at_step = __class__.end_at_step
        _restart_segments = __class__.segments
        _restart_scheduler = __class__.scheduler
        # DDIM works in timesteps, so convert the sigma-based segments first.
        segments = segments_to_timesteps(_restart_segments, model_denoise)
        segments = round_restart_segments_timesteps(ddim_timesteps, segments)
        _total_steps = len(ddim_timesteps) - 1 + calc_restart_steps(segments)
        step = 0

        def callback_wrapper(pred_x0, i):
            if img_callback is not None:
                img_callback(pred_x0, step)

        def ddim_simplified(x, timesteps, x_T=None, disable_pbar=False):
            if x_T is None:
                self.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
                x_T = self.stochastic_encode(x, torch.tensor(
                    [len(timesteps) - 1] * x.shape[0]).to(self.device), noise=torch.zeros_like(x), max_denoise=False)
            x, intermediates = ddim_sampler(
                self, timesteps, conditioning, callback=callback, img_callback=callback_wrapper, quantize_x0=quantize_x0,
                eta=eta, mask=mask, x0=x, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs, verbose=verbose, x_T=x_T, log_every_t=log_every_t,
                unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning,
                dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=denoise_function, extra_args=extra_args,
                to_zero=timesteps[0].item() == 0, end_step=len(timesteps) - 1, disable_pbar=disable_pbar
            )
            return x, intermediates

        intermediates = None
        real_steps = 0
        with trange(_total_steps, disable=disable_pbar) as pbar:
            # Iterate the schedule indices in reverse, clamped to the requested
            # [begin_at_step, end_at_step] window.
            rg = reversed(range(
                len(ddim_timesteps) - 1 - min(len(ddim_timesteps) - 1, end_at_step + 1),
                len(ddim_timesteps) - max(begin_at_step, 0),
            ))
            for i in rg:
                if real_steps > __class__.cfg_clamp_after_step:
                    extra_args['cond_scale'] = 1.0

                x0, intermediates = ddim_simplified(x0, ddim_timesteps[i:i + 2], x_T=x_T, disable_pbar=True)
                x_T = None
                pbar.update(1)
                step += 1
                real_steps += 1
                if ddim_timesteps[i].item() in segments:
                    seg = segments[ddim_timesteps[i].item()]
                    t_min, t_max, k, n_restart = ddim_timesteps[i], seg['t_max'], seg['k'], seg['n']
                    s_min, s_max = model_denoise.t_to_sigma(t_min), model_denoise.t_to_sigma(t_max)
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
                                             s_max, model_denoise, device=x0.device)
                    for _ in range(k):
                        x0 += torch.randn_like(x0) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            # Convert each restart sigma back to an integer timestep pair.
                            seg_ts = model_denoise.sigma_to_t(seg_sigmas[j]).to(torch.int32)
                            seg_ts_next = model_denoise.sigma_to_t(seg_sigmas[j + 1]).to(torch.int32)
                            x0, intermediates = ddim_simplified(x0, [seg_ts_next, seg_ts], disable_pbar=True)
                            pbar.update(1)
                            step += 1
        return x0, intermediates
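
# A rough call sketch for restart_sampling from a ComfyUI node (argument values
# are illustrative; "normal" as restart_scheduler assumes such a key exists in
# SCHEDULER_MAPPING, and model/positive/negative/latent_image come from the
# usual node inputs):
#
#   samples = restart_sampling(
#       model, seed=0, steps=20, cfg=7.0, sampler_name="euler", scheduler="normal",
#       positive=positive, negative=negative, latent_image=latent_image,
#       denoise=1.0, restart_info="[3, 2, 0.06, 0.30]", restart_scheduler="normal",
#       begin_at_step=0, end_at_step=20, disable_noise=False, force_full_denoise=True,
#   )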