# File size: 2,413 Bytes | revision 82ea528 (web-viewer header; line-number gutter removed)
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import nodes
class InverseCONST:
    """Sampling-type mixin whose hooks all return one of their inputs unchanged.

    Every method ignores ``sigma``; no scaling or mixing is applied anywhere.
    """

    def calculate_input(self, sigma, noise):
        """Return ``noise`` unchanged as the model input (``sigma`` is ignored)."""
        return noise

    def calculate_denoised(self, sigma, model_output, model_input):
        """Return ``model_output`` unchanged as the denoised result.

        ``sigma`` and ``model_input`` are accepted only to keep the hook
        signature. An earlier per-batch reshape of ``sigma`` was removed
        because its result was never read.
        """
        return model_output

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        """Return ``latent_image`` unchanged; ``noise`` is not mixed in."""
        return latent_image

    def inverse_noise_scaling(self, sigma, latent):
        """Return ``latent`` unchanged (``sigma`` is ignored)."""
        return latent
class LTXForwardModelSamplingPredNode:
    """Node that patches a model's ``model_sampling`` with an InverseCONST variant."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "ltxtricks"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a single required ``model`` of type ``MODEL``."""
        required_inputs = {"model": ("MODEL",)}
        return {"required": required_inputs}

    def patch(self, model):
        """Return a 1-tuple holding a clone of ``model`` with replaced sampling.

        The clone receives an object patch named ``"model_sampling"`` whose
        value is an instance of a class combining
        ``comfy.model_sampling.ModelSamplingFlux`` with ``InverseCONST``,
        built from the original model's config and set to ``shift=1.15``.
        """
        patched = model.clone()

        # Sampling class assembled on the fly: Flux schedule base + InverseCONST hooks.
        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, InverseCONST):
            pass

        sampler = ModelSamplingAdvanced(model.model.model_config)
        sampler.set_parameters(shift=1.15)
        patched.add_object_patch("model_sampling", sampler)
        return (patched,)
class ReverseCONST:
    """Sampling-type mixin that passes inputs through, except for one rescale.

    ``calculate_input``, ``calculate_denoised`` and ``noise_scaling`` return
    one of their arguments unchanged; only ``inverse_noise_scaling`` uses
    ``sigma``.
    """

    def calculate_input(self, sigma, noise):
        """Return ``noise`` unchanged as the model input (``sigma`` is ignored)."""
        return noise

    def calculate_denoised(self, sigma, model_output, model_input):
        """Return ``model_output`` unchanged as the denoised result.

        ``sigma`` and ``model_input`` are accepted only to keep the hook
        signature. The alternative ``model_input - model_output * sigma``
        was left disabled by the original author, so the per-batch reshape
        of ``sigma`` that fed it was dead code and has been removed.
        """
        return model_output

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        """Return ``latent_image`` unchanged; ``noise`` is not mixed in."""
        return latent_image

    def inverse_noise_scaling(self, sigma, latent):
        """Return ``latent`` divided by ``(1.0 - sigma)``.

        NOTE(review): ``sigma == 1`` makes the divisor zero; callers are
        presumably expected to avoid that value -- confirm.
        """
        return latent / (1.0 - sigma)
class LTXReverseModelSamplingPredNode:
    """Node that patches a model's ``model_sampling`` with a ReverseCONST variant."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "ltxtricks"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a single required ``model`` of type ``MODEL``."""
        required_inputs = {"model": ("MODEL",)}
        return {"required": required_inputs}

    def patch(self, model):
        """Return a 1-tuple holding a clone of ``model`` with replaced sampling.

        The clone receives an object patch named ``"model_sampling"`` whose
        value is an instance of a class combining
        ``comfy.model_sampling.ModelSamplingFlux`` with ``ReverseCONST``,
        built from the original model's config and set to ``shift=1.15``.
        """
        patched = model.clone()

        # Sampling class assembled on the fly: Flux schedule base + ReverseCONST hooks.
        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, ReverseCONST):
            pass

        sampler = ModelSamplingAdvanced(model.model.model_config)
        sampler.set_parameters(shift=1.15)
        patched.add_object_patch("model_sampling", sampler)
        return (patched,)
|