File size: 4,675 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from tqdm import trange

from comfy.samplers import KSAMPLER, CFGGuider, sampling_function


class FlowEditGuider(CFGGuider):
    def __init__(self, model_patcher):
        super().__init__(model_patcher)
        self.cfgs = {}

    def set_conds(self, **kwargs):
        self.inner_set_conds(kwargs)

    def set_cfgs(self, **kwargs):
        self.cfgs = {**kwargs}

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        latent_type = model_options['transformer_options']['latent_type']
        positive = self.conds.get(f'{latent_type}_positive', None)
        negative = self.conds.get(f'{latent_type}_negative', None)
        cfg = self.cfgs.get(latent_type, self.cfg)
        return sampling_function(self.inner_model, x, timestep, negative, positive, cfg, model_options=model_options, seed=seed)


class LTXFlowEditCFGGuiderNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "model": ("MODEL",),
                        "source_pos": ("CONDITIONING", ),
                        "source_neg": ("CONDITIONING", ),
                        "target_pos": ("CONDITIONING", ),
                        "target_neg": ("CONDITIONING", ),
                        "source_cfg": ("FLOAT", {"default": 2, "min": 0, "max": 0xffffffffffffffff, "step": 0.01 }),
                        "target_cfg": ("FLOAT", {"default": 4.5, "min": 0, "max": 0xffffffffffffffff, "step": 0.01 }),
                     }
                }

    RETURN_TYPES = ("GUIDER",)

    FUNCTION = "get_guider"
    CATEGORY = "ltxtricks"

    def get_guider(self, model, source_pos, source_neg, target_pos, target_neg, source_cfg, target_cfg):
        guider = FlowEditGuider(model)
        guider.set_conds(source_positive=source_pos, source_negative=source_neg, target_positive=target_pos, target_negative=target_neg)
        guider.set_cfgs(source=source_cfg, target=target_cfg)
        return (guider,)


def get_flowedit_sample(skip_steps, refine_steps, seed):
    generator = torch.manual_seed(seed)
    @torch.no_grad()
    def flowedit_sample(model, x_init, sigmas, extra_args=None, callback=None, disable=None):
        extra_args = {} if extra_args is None else extra_args

        model_options = extra_args.get('model_options', {})
        transformer_options = model_options.get('transformer_options', {})
        transformer_options = {**transformer_options}
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        source_extra_args = {**extra_args, 'model_options': { 'transformer_options': { **transformer_options,'latent_type': 'source '} }}

        sigmas = sigmas[skip_steps:]

        x_tgt = x_init.clone()
        N = len(sigmas)-1
        s_in = x_init.new_ones([x_init.shape[0]])

        for i in trange(N, disable=disable):
            sigma = sigmas[i]
            noise = torch.randn(x_init.shape, generator=generator).to(x_init.device)

            zt_src = (1-sigma)*x_init + sigma*noise
            
            if i < N-refine_steps:
                zt_tgt = x_tgt + zt_src - x_init
                transformer_options['latent_type'] = 'source'
                source_extra_args['model_options']['transformer_options']['latent_type'] = 'source'
                vt_src = model(zt_src, sigma*s_in, **source_extra_args)
            else:
                if i == N-refine_steps:
                    x_tgt = x_tgt + zt_src - x_init
                zt_tgt = x_tgt
                vt_src = 0
                
            transformer_options['latent_type'] = 'target'
            vt_tgt = model(zt_tgt, sigma*s_in, **extra_args)
            
            v_delta = vt_tgt - vt_src
            x_tgt += (sigmas[i+1] - sigmas[i]) * v_delta
            
            if callback is not None:
                callback({'x': x_tgt, 'denoised': x_tgt, 'i': i+skip_steps, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return x_tgt
    
    return flowedit_sample


class LTXFlowEditSamplerNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { 
            "skip_steps": ("INT", {"default": 4, "min": 0, "max": 0xffffffffffffffff }),
            "refine_steps": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
        }, "optional": {
        }}
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"

    CATEGORY = "ltxtricks"

    def build(self, skip_steps, refine_steps, seed):
        sampler = KSAMPLER(get_flowedit_sample(skip_steps, refine_steps, seed))
        return (sampler, )