import ast
from itertools import tee

import torch
from tqdm.auto import trange

from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.ldm.models.diffusion.ddim import DDIMSampler
from comfy.samplers import CompVisVDenoiser
from comfy.utils import ProgressBar
from nodes import common_ksampler

from ..utils.refined_exp_solver import RefinedExpCallbackPayload, _refined_exp_sosu_step
from .restart_schedulers import SCHEDULER_MAPPING
def pairwise(iterable):
"s -> (s0, s1), (s1, s2), (s2, s3), ..."
a, b = tee(iterable)
next(b, None)
return zip(a, b)
def add_restart_segment(restart_segments, n_restart, k, t_min, t_max):
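    """Append a segment dict {'n', 'k', 't_min', 't_max'} to restart_segments, creating the list if it is None."""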
if restart_segments is None:
restart_segments = []
restart_segments.append({'n': n_restart, 'k': k, 't_min': t_min, 't_max': t_max})
return restart_segments
def prepare_restart_segments(restart_info):
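    """Parse the restart_info string, a comma-separated sequence of [n, k, t_min, t_max] lists
    (e.g. "[3, 2, 0.06, 0.30], [3, 1, 0.30, 0.59]"; the values here are purely illustrative),
    into a list of segment dicts.
    """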
try:
restart_arrays = ast.literal_eval(f"[{restart_info}]")
    except (SyntaxError, ValueError):
        print("Ill-formed restart segments")
        raise
restart_segments = []
for arr in restart_arrays:
if len(arr) != 4:
raise ValueError("Restart segment must have 4 values")
n_restart, k, t_min, t_max = arr
restart_segments = add_restart_segment(restart_segments, n_restart, k, t_min, t_max)
return restart_segments
def round_restart_segments(sigmas, restart_segments):
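    """Snap each segment's t_min to the nearest sigma in the schedule and return a dict
    keyed by that sigma, so segments can be matched by exact value while sampling.
    """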
    s_max = sigmas[0]
    t_min_mapping = {}
    for segment in reversed(restart_segments):  # Reversed so earlier segments take precedence on collisions
        if segment['t_max'] > s_max:
            continue  # Segment starts above the schedule's maximum sigma; drop it
t_min_neighbor = min(sigmas, key=lambda s: abs(s - segment['t_min'])).item()
t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
return t_min_mapping
def segments_to_timesteps(restart_segments, model):
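    """Convert each segment's sigma bounds (t_min, t_max) into model timesteps for the DDIM wrapper."""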
timesteps = []
for segment in restart_segments:
t_min, t_max = model.sigma_to_t(torch.tensor(
[segment['t_min'], segment['t_max']], device=model.log_sigmas.device))
ts_segment = {'n': segment['n'], 'k': segment['k'], 't_min': t_min, 't_max': t_max}
timesteps.append(ts_segment)
return timesteps
def round_restart_segments_timesteps(timesteps, restart_segments):
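    """Like round_restart_segments, but snaps each segment's t_min to the nearest DDIM timestep."""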
t_min_mapping = {}
    for segment in reversed(restart_segments):  # Reversed so earlier segments take precedence on collisions
t_min_neighbor = min(timesteps, key=lambda ts: abs(ts - segment['t_min'])).item()
t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
return t_min_mapping
def calc_sigmas(scheduler, n, sigma_min, sigma_max, model, device):
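    """Build the sigma schedule for a restart segment using the selected restart scheduler."""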
return SCHEDULER_MAPPING[scheduler](model, n, sigma_min, sigma_max, device)
def calc_restart_steps(restart_segments):
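    """Total extra solver steps added by restarts: each segment contributes (n - 1) steps per restart, k times."""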
restart_steps = 0
for segment in restart_segments.values():
restart_steps += (segment['n'] - 1) * segment['k']
return restart_steps
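# Module-level state shared with the monkey-patched sampler functions below.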
_total_steps = 0
_restart_segments = None
_restart_scheduler = None
def restart_sampling(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise, restart_info, restart_scheduler, begin_at_step, end_at_step, disable_noise, force_full_denoise):
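    """Run common_ksampler with the chosen sampler temporarily monkey-patched to honor the
    restart segments parsed from restart_info. The progress bar is patched to account for
    the extra restart steps; both patches are restored when sampling finishes.
    """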
global _total_steps, _restart_segments, _restart_scheduler
_restart_scheduler = restart_scheduler
_restart_segments = prepare_restart_segments(restart_info)
if sampler_name == "res":
sampler_wrapper = RESWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
elif sampler_name == "ddim":
sampler_wrapper = DDIMWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
else:
sampler_wrapper = KSamplerRestartWrapper(sampler_name, begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
# Add the additional steps to the progress bar
pbar_update_absolute = ProgressBar.update_absolute
def pbar_update_absolute_wrapper(self, value, total=None, preview=None):
pbar_update_absolute(self, value, _total_steps, preview)
ProgressBar.update_absolute = pbar_update_absolute_wrapper
try:
samples = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler,
positive, negative, latent_image, denoise=denoise, force_full_denoise=force_full_denoise, disable_noise=disable_noise)
finally:
sampler_wrapper.cleanup()
ProgressBar.update_absolute = pbar_update_absolute
return samples
class RestartWrapper:
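    """Base class for the sampler monkey-patching wrappers."""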
def cleanup(self):
pass
class RESWrapper(RestartWrapper):
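    """Registers a `sample_res` function on comfy's k-diffusion sampling module that runs
    the refined-exponential (SOSU) solver with restart segments.
    """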
    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        # begin_at_step and end_at_step are accepted for interface parity but unused here
self.__class__.segments = segments
self.__class__.scheduler = scheduler
self.__class__.refiner_stage = False
self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
self.sample_func_name = "sample_res"
setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)
@staticmethod
@torch.no_grad()
def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
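        """Refined-exponential sampling loop: step through the sigma schedule and, whenever a
        segment's rounded t_min is reached, re-noise up to t_max and re-denoise the segment
        k times.
        """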
global _total_steps, _restart_segments, _restart_scheduler
_restart_scheduler = __class__.scheduler
_restart_segments = __class__.segments
segments = round_restart_segments(sigmas, _restart_segments)
_total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
step = 0
real_steps = 0
def callback_wrapper(x):
x["i"] = step
if callback is not None:
callback(x)
        ita: torch.FloatTensor = torch.zeros((1,), device=x.device)  # 'ita' (eta): per-step churn amount; zero keeps each step noise-free
        simple_phi_calc = True  # use the simplified phi-function computation
        c2 = .5  # second-stage coefficient passed to the refined-exp step
with trange(_total_steps, disable=disable) as pbar:
for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
if real_steps > __class__.cfg_clamp_after_step:
extra_args['cond_scale'] = 1.0
            eps: torch.FloatTensor = torch.randn_like(x, device=x.device)
            sigma_hat = sigma * (1 + ita)
            x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps  # churn up to sigma_hat (no-op while ita == 0)
x_next, denoised, denoised2 = _refined_exp_sosu_step(
model,
x_hat,
sigma_hat,
sigma_next,
c2=c2,
extra_args=extra_args,
pbar=pbar,
simple_phi_calc=simple_phi_calc,
)
if callback is not None:
payload = RefinedExpCallbackPayload(
x=x,
i=step,
sigma=sigma,
sigma_hat=sigma_hat,
denoised=denoised,
denoised2=denoised2,
)
callback(payload)
x = x_next
pbar.update(1)
step += 1
real_steps += 1
if sigmas[i].item() in segments:
seg = segments[sigmas[i].item()]
s_min, s_max, k, n_restart = sigmas[i+1], seg['t_max'], seg['k'], seg['n']
seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
s_max, model.inner_model, device=x.device)
                for _ in range(k):
                    # Restart: inject noise to jump from s_min back up to s_max before re-denoising
                    x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
for j in range(n_restart - 1):
                        eps: torch.FloatTensor = torch.randn_like(x, device=x.device)
                        sigma_hat = seg_sigmas[j] * (1 + ita)
                        x_hat = x + (sigma_hat ** 2 - seg_sigmas[j] ** 2) ** .5 * eps
                        x_next, denoised, denoised2 = _refined_exp_sosu_step(
                            model,
                            x_hat,
                            sigma_hat,
                            seg_sigmas[j + 1],
                            c2=c2,
                            extra_args=extra_args,
                            pbar=pbar,
                            simple_phi_calc=simple_phi_calc,
                        )
                        x = x_next  # carry the restart step forward
                        pbar.update(1)
                        step += 1
            if __class__.refiner_stage:
                # Refiner stage: finish with the model's denoised prediction at the last sigma of the loop
                eps: torch.FloatTensor = torch.randn_like(x, device=x.device)
                sigma_hat = sigma * (1 + ita)
                x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps
                x_next: torch.FloatTensor = model(x_hat, sigma.to(x_hat.device), **extra_args)
                pbar.update(1)
                x = x_next
return x
class KSamplerRestartWrapper(RestartWrapper):
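    """Wraps an existing k-diffusion sampler (`sample_<name>`) so that restart segments
    re-noise and re-denoise at the configured sigmas; cleanup() restores the original sampler.
    """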
ksampler = None
    def __init__(self, sampler_name, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
self.sample_func_name = "sample_{}".format(sampler_name)
self.__class__.segments = segments
self.__class__.ksampler = getattr(k_diffusion_sampling, self.sample_func_name)
self.__class__.begin_at_step = begin_at_step
self.__class__.end_at_step = end_at_step
self.__class__.scheduler = scheduler
self.__class__.refiner_stage = False
self.__class__.original_sigmas = None
self.__class__.total_steps = 0
self.__class__.continuation_step = 0
self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)
def cleanup(self):
setattr(k_diffusion_sampling, self.sample_func_name, KSamplerRestartWrapper.ksampler)
@staticmethod
@torch.no_grad()
def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
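        """Drive the wrapped k-diffusion sampler one sigma pair at a time, injecting restart
        segments wherever the schedule reaches a segment's rounded t_min.
        """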
global _total_steps, _restart_segments, _restart_scheduler
ksampler = __class__.ksampler
_restart_scheduler = __class__.scheduler
_restart_segments = __class__.segments
begin_at_step = __class__.begin_at_step
end_at_step = __class__.end_at_step
segments = round_restart_segments(sigmas, _restart_segments)
_total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
step = 0
if not __class__.refiner_stage:
__class__.original_sigmas = sigmas
__class__.total_steps = _total_steps
        else:
            # TODO: recompute begin_at_step and end_at_step from the number of steps already completed
            step = 0
def callback_wrapper(x):
x["i"] = step
if callback is not None:
callback(x)
real_steps = 0
with trange(_total_steps, disable=disable) as pbar:
for i in range(len(sigmas) - 1):
if real_steps > __class__.cfg_clamp_after_step:
extra_args['cond_scale'] = 1.0
            # TODO: honor end_at_step by breaking out early and recording the continuation step:
            # if i + 1 > end_at_step:
            #     __class__.continuation_step = i
            #     break
            x = ksampler(model, x, sigmas[i:i + 2].to(x.device),
                         extra_args=extra_args, callback=callback_wrapper, disable=True)
pbar.update(1)
step += 1
real_steps += 1
if sigmas[i].item() in segments:
seg = segments[sigmas[i].item()]
s_min, s_max, k, n_restart = sigmas[i+1], seg['t_max'], seg['k'], seg['n']
seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
s_max, model.inner_model, device=x.device)
for _ in range(k):
                    # Restart: inject noise to jump from s_min back up to s_max before re-denoising
                    x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
for j in range(n_restart - 1):
                        x = ksampler(model, x, seg_sigmas[j:j + 2],
                                     extra_args=extra_args, callback=callback_wrapper, disable=True)
pbar.update(1)
step += 1
return x
class DDIMWrapper(RestartWrapper):
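    """Monkey-patches DDIMSampler.sample_custom with a restart-aware loop that works in
    timestep space; cleanup() restores the original method.
    """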
    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
self.__class__.sample_custom = DDIMSampler.sample_custom
self.__class__.begin_at_step = begin_at_step
self.__class__.end_at_step = end_at_step
self.__class__.segments = segments
self.__class__.scheduler = scheduler
self.__class__.refiner_stage = False
self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
DDIMSampler.sample_custom = self.ddim_wrapper
def cleanup(self):
DDIMSampler.sample_custom = self.__class__.sample_custom
@staticmethod
@torch.no_grad()
def ddim_wrapper(self, ddim_timesteps, conditioning, callback=None, img_callback=None, quantize_x0=False,
eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None,
corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.,
unconditional_conditioning=None, dynamic_threshold=None, ucg_schedule=None, denoise_function=None,
extra_args=None, to_zero=True, end_step=None, disable_pbar=False, **kwargs):
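        """DDIM sampling loop with restarts: walk the timesteps from high noise to low and,
        when a segment's rounded t_min timestep is reached, add noise back up to t_max and
        re-denoise that span k times.
        """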
global _total_steps, _restart_segments, _restart_scheduler
ddim_sampler = __class__.sample_custom
        model_denoise = CompVisVDenoiser(self.model)  # k-diffusion wrapper used for sigma <-> timestep conversions
begin_at_step = __class__.begin_at_step
end_at_step = __class__.end_at_step
_restart_segments = __class__.segments
_restart_scheduler = __class__.scheduler
segments = segments_to_timesteps(_restart_segments, model_denoise)
segments = round_restart_segments_timesteps(ddim_timesteps, segments)
_total_steps = len(ddim_timesteps) - 1 + calc_restart_steps(segments)
step = 0
        def callback_wrapper(pred_x0, i):
            # Report the global step index instead of the sub-range index
            if img_callback is not None:
                img_callback(pred_x0, step)
def ddim_simplified(x, timesteps, x_T=None, disable_pbar=False):
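            """Run the original sample_custom over a sub-range of timesteps; when x_T is None,
            deterministically re-encode x at the sub-range's last timestep first.
            """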
if x_T is None:
self.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
x_T = self.stochastic_encode(x, torch.tensor(
[len(timesteps) - 1] * x.shape[0]).to(self.device), noise=torch.zeros_like(x), max_denoise=False)
x, intermediates = ddim_sampler(
self, timesteps, conditioning, callback=callback, img_callback=callback_wrapper, quantize_x0=quantize_x0,
eta=eta, mask=mask, x0=x, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, verbose=verbose, x_T=x_T, log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=denoise_function, extra_args=extra_args,
to_zero=timesteps[0].item() == 0, end_step=len(timesteps) - 1, disable_pbar=disable_pbar
)
return x, intermediates
intermediates = None
real_steps = 0
with trange(_total_steps, disable=disable_pbar) as pbar:
            # Walk the timesteps from high noise to low, honoring begin_at_step and end_at_step
            rg = reversed(range(
                len(ddim_timesteps) - 1 - min(len(ddim_timesteps) - 1, end_at_step + 1),  # start
                len(ddim_timesteps) - max(begin_at_step, 0),  # end
            ))
for i in rg:
if real_steps > __class__.cfg_clamp_after_step:
extra_args['cond_scale'] = 1.0
x0, intermediates = ddim_simplified(x0, ddim_timesteps[i:i + 2], x_T=x_T, disable_pbar=True)
x_T = None
pbar.update(1)
step += 1
real_steps += 1
if ddim_timesteps[i].item() in segments:
seg = segments[ddim_timesteps[i].item()]
t_min, t_max, k, n_restart = ddim_timesteps[i], seg['t_max'], seg['k'], seg['n']
s_min, s_max = model_denoise.t_to_sigma(t_min), model_denoise.t_to_sigma(t_max)
seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
s_max, model_denoise, device=x0.device)
for _ in range(k):
                    # Restart: inject noise to jump from s_min back up to s_max before re-denoising
                    x0 += torch.randn_like(x0) * (s_max ** 2 - s_min ** 2) ** 0.5
for j in range(n_restart - 1):
seg_ts = model_denoise.sigma_to_t(seg_sigmas[j]).to(torch.int32)
seg_ts_next = model_denoise.sigma_to_t(seg_sigmas[j + 1]).to(torch.int32)
x0, intermediates = ddim_simplified(x0, [seg_ts_next, seg_ts], disable_pbar=True)
pbar.update(1)
step += 1
return x0, intermediates