jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
import comfy.model_management as mm
from ..sampling.sampler import run_sampler
from ..sampling.sampling_functions import get_rf_forward_sample_fn
from ..utils.sampling_utils import prepare_conds
class MochiWrapperUnsamplerNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MOCHIMODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 30.0, "step": 0.01}),
"sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
"latents": ("LATENT", ),
}
}
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("samples",)
FUNCTION = "process"
CATEGORY = "MochiEdit"
def process(self, model, positive, negative, seed, gamma, sigmas, latents):
mm.soft_empty_cache()
sigmas = sigmas.tolist()
if sigmas[0] != 0.0:
sigmas = [0.0, *sigmas]
latents = latents['samples']
positive, negative = prepare_conds(positive, negative)
sampler_fn = get_rf_forward_sample_fn(gamma, seed)
latents = run_sampler(model, latents, positive, negative, sigmas, 1.0, sampler_fn, False, 0)
mm.soft_empty_cache()
return ({"samples": latents},)