"""Vyro SDXL sampler nodes for ComfyUI (plain SDXL sampling and restart sampling)."""

import importlib
import random

import comfy
import comfy.model_management
import comfy.sample
import comfy.samplers
import comfy.utils
from comfy.utils import ProgressBar
import numpy as np
import torch

from ..utils import VyroParams
from ..utils.restart_sampling import (
    DDIMWrapper,
    KSamplerRestartWrapper,
    RESWrapper,
    SCHEDULER_MAPPING,
    _restart_scheduler,
    _restart_segments,
    _total_steps,
    prepare_restart_segments,
)
from ..utils.sdxl_ksampler import CfgMethods, sdxl_ksampler

# Handle on the restart_sampling module itself. Module-level counters such as
# `_total_steps` must be read through it: a name bound via `from ... import`
# is a copy taken at import time and never reflects later rebinding inside
# that module. (Resolved with the same relative path as the import above.)
_restart_sampling_module = importlib.import_module("..utils.restart_sampling", __package__)

# Samplers hidden from these nodes. The SDE samplers cannot be used with
# restarts; the uni_pc variants are also excluded (reason not visible here).
_RESTART_INCOMPATIBLE_SAMPLERS = frozenset({
    "uni_pc",
    "uni_pc_bh2",
    "dpmpp_sde",
    "dpmpp_sde_gpu",
    "dpmpp_2m_sde",
    "dpmpp_2m_sde_gpu",
})


def get_supported_samplers():
    """Return ComfyUI's sampler names minus the excluded ones, plus "res".

    Names missing from the running ComfyUI build are simply skipped, and "res"
    is appended only if it is not already present: `VyroKRestartSampler.__init__`
    appends it to the global `KSampler.SAMPLERS` list.
    """
    samplers = [
        name for name in comfy.samplers.KSampler.SAMPLERS
        if name not in _RESTART_INCOMPATIBLE_SAMPLERS
    ]
    if "res" not in samplers:
        samplers.append("res")
    return samplers


def get_supported_restart_schedulers():
    """Return the restart-scheduler names (the keys of SCHEDULER_MAPPING)."""
    return list(SCHEDULER_MAPPING.keys())


def sdxl_restarts_ksampler(base_model, refiner_model, seed, base_steps, refiner_steps, cfg,
                           sampler_name, scheduler, base_positive, base_negative,
                           refiner_positive, refiner_negative, latent, segments,
                           restart_scheduler, denoise=1.0, disable_noise=False,
                           start_step=None, last_step=None, force_full_denoise=False,
                           refiner_detail_boost=0.0, cfg_clamp_after_step=0):
    """Run `sdxl_ksampler` with a restart-sampling wrapper installed.

    A wrapper is chosen by `sampler_name`: RESWrapper for "res", DDIMWrapper for
    "ddim", and KSamplerRestartWrapper otherwise. It is built from the parsed
    `segments` string and `restart_scheduler`, then passed to `sdxl_ksampler`
    as `restart_wrapper`. While sampling runs, `ProgressBar.update_absolute` is
    patched so the bar's total comes from restart_sampling's `_total_steps`.
    The wrapper's `cleanup()` is called and the patch is undone even if
    sampling raises.

    `start_step`, `last_step` and `cfg_clamp_after_step` are given to the
    wrapper. The sdxl_ksampler pass itself always uses start_step=0
    (NOTE(review): confirm this is intended).

    Returns:
        Whatever `sdxl_ksampler` returns (callers index `[0]` for the latent).
    """
    # These rebind only this module's copies of the imported names. The
    # wrappers receive the values explicitly below. Kept for compatibility with
    # anything that reads them from this module.
    global _restart_segments, _restart_scheduler
    _restart_scheduler = restart_scheduler
    _restart_segments = prepare_restart_segments(segments)

    wrapper_args = (start_step, last_step, _restart_segments, _restart_scheduler,
                    cfg_clamp_after_step)
    if sampler_name == "res":
        sampler_wrapper = RESWrapper(*wrapper_args)
    elif sampler_name == "ddim":
        sampler_wrapper = DDIMWrapper(*wrapper_args)
    else:
        sampler_wrapper = KSamplerRestartWrapper(sampler_name, *wrapper_args)

    # Add the additional restart steps to the progress bar.
    pbar_update_absolute = ProgressBar.update_absolute

    def pbar_update_absolute_wrapper(self, value, total=None, preview=None):
        # Read the counter live at call time. The `_total_steps` imported at
        # the top of this file is a frozen import-time copy.
        pbar_update_absolute(self, value, _restart_sampling_module._total_steps, preview)

    ProgressBar.update_absolute = pbar_update_absolute_wrapper
    try:
        return sdxl_ksampler(
            base_model, refiner_model, seed, base_steps, refiner_steps, cfg,
            sampler_name, scheduler, base_positive, base_negative,
            refiner_positive, refiner_negative, latent,
            denoise=denoise,
            # Forward the caller's choice; previously hard-coded to False, which
            # made the node's add_noise="disable" option a no-op.
            disable_noise=disable_noise,
            start_step=0,
            last_step=last_step,
            force_full_denoise=force_full_denoise,
            dynamic_base_cfg=cfg,
            dynamic_refiner_cfg=cfg,
            cfg_method=CfgMethods.TONEMAP,
            refiner_detail_boost=refiner_detail_boost,
            restart_wrapper=sampler_wrapper,
        )
    finally:
        sampler_wrapper.cleanup()
        ProgressBar.update_absolute = pbar_update_absolute


class VyroSDXLSampler:
    """ComfyUI node: SDXL base (+ optional refiner) sampling driven by VyroParams.

    Seed, step count and CFG come from `params`. When a refiner model is
    connected, `base_ratio` splits the steps between the base and refiner.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "params": ("VYRO_PARAMS",),
                "base_model": ("MODEL",),
                "base_positive": ("CONDITIONING", ),
                "base_negative": ("CONDITIONING", ),
                "sampler_name": (get_supported_samplers(), ),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                "base_ratio": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "refiner_detail_boost": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05},),
            },
            "optional": {
                "refiner_model": ("MODEL",),
                "refiner_positive": ("CONDITIONING", ),
                "refiner_negative": ("CONDITIONING", ),
                "latent_image": ("LATENT",),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    CATEGORY = "Vyro/Samplers"

    def sample(self, params: VyroParams, base_model, base_positive, base_negative,
               sampler_name, scheduler, base_ratio, denoise, refiner_detail_boost,
               refiner_model=None, refiner_positive=None, refiner_negative=None,
               latent_image=None):
        """Sample with the base model, then the refiner if connected and given steps.

        Falls back to `params.latents` when no `latent_image` is connected.
        Returns `sdxl_ksampler`'s result. When `denoise` is effectively zero,
        returns the input latent unchanged as a 1-tuple.
        """
        if latent_image is None:
            latent_image = params.latents
        noise_seed = params.seed
        steps = params.steps
        cfg = params.cfg
        cfg_method = CfgMethods.TONEMAP
        dynamic_base_cfg = 0.0
        dynamic_refiner_cfg = 0.0

        has_refiner_model = refiner_model is not None
        # The +0.0001 guards against float error truncating e.g. 0.8 * 10 to 7.
        base_steps = int(steps * (base_ratio + 0.0001)) if has_refiner_model else steps
        refiner_steps = max(0, steps - base_steps)

        if denoise < 0.005:
            # Nothing to denoise: hand back the input latent (which may be a
            # user-supplied latent_image, not necessarily params.latents).
            return (latent_image,)

        # Seed every RNG the samplers might touch, for reproducibility.
        torch.manual_seed(params.seed)
        random.seed(params.seed)
        np.random.seed(params.seed)

        if refiner_steps == 0 or not has_refiner_model:
            result = sdxl_ksampler(base_model, None, noise_seed, base_steps, 0, cfg, sampler_name, scheduler, base_positive, base_negative, None, None, latent_image, denoise=denoise, disable_noise=False, start_step=0, last_step=steps, force_full_denoise=True, dynamic_base_cfg=dynamic_base_cfg, cfg_method=cfg_method)
        else:
            result = sdxl_ksampler(base_model, refiner_model, noise_seed, base_steps, refiner_steps, cfg, sampler_name, scheduler, base_positive, base_negative, refiner_positive, refiner_negative, latent_image, denoise=denoise, disable_noise=False, start_step=0, last_step=steps, force_full_denoise=True, dynamic_base_cfg=dynamic_base_cfg, dynamic_refiner_cfg=dynamic_refiner_cfg, cfg_method=cfg_method, refiner_detail_boost=refiner_detail_boost)
        return result


class VyroKRestartSampler:
    """ComfyUI node: SDXL base/refiner sampling with restart segments.

    Outputs the sampled LATENT and passes `base_model` through as MODEL.
    """

    def __init__(self) -> None:
        # Make "res" a valid sampler name in ComfyUI's global sampler list.
        if 'res' not in comfy.samplers.KSampler.SAMPLERS:
            comfy.samplers.KSampler.SAMPLERS.append('res')

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "params": ("VYRO_PARAMS",),
                "base_model": ("MODEL",),
                "base_positive": ("CONDITIONING", ),
                "base_negative": ("CONDITIONING", ),
                "refiner_model": ("MODEL",),
                "refiner_positive": ("CONDITIONING", ),
                "refiner_negative": ("CONDITIONING", ),
                "sampler_name": (get_supported_samplers(), ),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                "segments": ("STRING", {"default": "[3,2,0.06,0.30],[3,1,0.30,0.59]", "multiline": False}),
                "restart_scheduler": (get_supported_restart_schedulers(), ),
                "begin_at_step": ("INT", {"default": 1, "min": 0, "max": 10000}),
                "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                "add_noise": (["enable", "disable"], ),
                "return_with_leftover_noise": (["disable", "enable"], ),
                "base_ratio": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "refiner_detail_boost": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05},),
                "cfg_clamp_after_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
            },
            "optional": {
                "refiner_prep_steps": ("INT", {"default": 0, "min": 0, "max": 10}),
                "latent_image": ("LATENT",),
            },
        }

    RETURN_TYPES = ("LATENT", "MODEL",)
    FUNCTION = "sample"
    CATEGORY = "Vyro/Samplers"

    def sample(self, params: VyroParams, base_model, base_positive, base_negative,
               refiner_model, refiner_positive, refiner_negative, sampler_name,
               scheduler, segments, restart_scheduler, begin_at_step, end_at_step,
               add_noise, return_with_leftover_noise, base_ratio, denoise,
               refiner_detail_boost, refiner_prep_steps=0, cfg_clamp_after_step=0,
               latent_image=None):
        """Run restart sampling and return `(latent, base_model)`.

        The refiner is dropped when `base_ratio` assigns every step to the base
        model. When `denoise` is effectively zero, the input latent is returned
        unchanged, still as a 2-tuple so it matches RETURN_TYPES.
        """
        force_full_denoise = return_with_leftover_noise != "enable"
        disable_noise = add_noise == "disable"

        steps = params.steps
        # The +0.0001 guards against float error truncating e.g. 0.8 * 10 to 7.
        base_steps = int(steps * (base_ratio + 0.0001))
        refiner_steps = max(0, steps - base_steps)

        if latent_image is None:
            latent_image = params.latents

        if denoise < 0.01:
            # Must match RETURN_TYPES ("LATENT", "MODEL"); a 1-tuple here broke
            # any downstream consumer of the MODEL output.
            return (latent_image, base_model)

        # NOTE: refiner_prep_steps is accepted for workflow compatibility but is
        # currently unused (the refiner preconditioning pass is disabled).

        # Seed every RNG the samplers might touch, for reproducibility.
        torch.manual_seed(params.seed)
        random.seed(params.seed)
        np.random.seed(params.seed)

        # When the base model covers every step, sample without the refiner.
        active_refiner = None if base_steps >= steps else refiner_model
        out = sdxl_restarts_ksampler(
            base_model, active_refiner, params.seed, base_steps, refiner_steps, params.cfg,
            sampler_name, scheduler, base_positive, base_negative,
            refiner_positive, refiner_negative, latent_image, segments,
            restart_scheduler=restart_scheduler,
            denoise=denoise,
            disable_noise=disable_noise,
            start_step=begin_at_step,
            last_step=end_at_step,
            force_full_denoise=force_full_denoise,
            refiner_detail_boost=refiner_detail_boost,
            cfg_clamp_after_step=cfg_clamp_after_step,
        )[0]
        return (out, base_model)


NODE_CLASS_MAPPINGS = {
    "Vyro KRestart Sampler": VyroKRestartSampler,
    "Vyro SDXL Sampler": VyroSDXLSampler,
}

# Keys must match NODE_CLASS_MAPPINGS: ComfyUI looks display names up by node type.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro KRestart Sampler": "Vyro KRestart Sampler",
    "Vyro SDXL Sampler": "Vyro SDXL Sampler",
}