File size: 4,491 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from tqdm import trange

from comfy.samplers import KSAMPLER, CFGGuider, sampling_function


class FlowEditGuider(CFGGuider):
    """CFGGuider subclass that selects conditioning and CFG scale per latent type.

    The latent type is read on every prediction from
    ``model_options['transformer_options']['latent_type']`` and used to build
    the lookup keys ``'<latent_type>_positive'`` / ``'<latent_type>_negative'``.
    """

    def __init__(self, model_patcher):
        super().__init__(model_patcher)
        # CFG overrides keyed by latent type; a missing key falls back to self.cfg.
        self.cfgs = {}

    def set_conds(self, **kwargs):
        """Pass all keyword conditionings, as a single dict, to ``inner_set_conds``."""
        self.inner_set_conds(kwargs)

    def set_cfgs(self, **kwargs):
        """Replace the CFG override table with a fresh copy of the keyword arguments."""
        self.cfgs = dict(kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run ``sampling_function`` with the conditioning that matches the latent type.

        Raises ``KeyError`` when ``model_options`` has no
        ``['transformer_options']['latent_type']`` entry. A conditioning key that
        was never registered resolves to ``None``.
        """
        kind = model_options['transformer_options']['latent_type']
        cond_pos = self.conds.get(f'{kind}_positive')
        cond_neg = self.conds.get(f'{kind}_negative')
        scale = self.cfgs.get(kind, self.cfg)
        # Argument order is (negative, positive), exactly as in the call being replaced.
        return sampling_function(
            self.inner_model,
            x,
            timestep,
            cond_neg,
            cond_pos,
            scale,
            model_options=model_options,
            seed=seed,
        )


class HYFlowEditGuiderNode:
    """Node that packages a model plus source/target conditioning into a GUIDER."""

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "hunyuanloom"

    @classmethod
    def INPUT_TYPES(s):
        """Declare the three required inputs: model, source_cond and target_cond."""
        required_inputs = {}
        required_inputs["model"] = ("MODEL",)
        required_inputs["source_cond"] = ("CONDITIONING",)
        required_inputs["target_cond"] = ("CONDITIONING",)
        return {"required": required_inputs}

    def get_guider(self, model, source_cond, target_cond):
        """Create a FlowEditGuider with both positive conds registered and CFG set to 1.0.

        Returns:
            A one-element tuple holding the guider.
        """
        result = FlowEditGuider(model)
        positive_conds = {
            "source_positive": source_cond,
            "target_positive": target_cond,
        }
        result.set_conds(**positive_conds)
        result.set_cfg(1.0)
        return (result,)
    

def get_flowedit_sample(skip_steps, refine_steps, generator):
    """Build the FlowEdit sampling function.

    Args:
        skip_steps: Number of leading entries dropped from ``sigmas`` before
            the loop starts.
        refine_steps: Number of trailing steps that skip the source model call
            and use only the target prediction.
        generator: ``torch.Generator`` used to draw fresh noise on every step.
            The noise is created without a ``device`` argument and then moved
            to ``x_init.device``.

    Returns:
        A function ``flowedit_sample(model, x_init, sigmas, extra_args=None,
        callback=None, disable=None)`` that returns the edited latent.
    """
    @torch.no_grad()
    def flowedit_sample(model, x_init, sigmas, extra_args=None, callback=None, disable=None):
        extra_args = {} if extra_args is None else extra_args

        # Swap in a copy of transformer_options so that writing 'latent_type'
        # inside the loop does not touch the dict object that was passed in.
        model_options = extra_args.get('model_options', {})
        transformer_options = model_options.get('transformer_options', {})
        transformer_options = {**transformer_options}
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        # Kwargs for the source pass: same extra_args, but model_options is
        # replaced by a minimal dict that only carries the latent-type tag.
        # The tag has to be exactly 'source' (no trailing space): the guider in
        # this file builds its conditioning key as f'{latent_type}_positive',
        # and the registered key is 'source_positive'.
        # NOTE(review): every other model_options entry is dropped for the
        # source pass - confirm that is intended.
        source_extra_args = {**extra_args, 'model_options': {'transformer_options': {'latent_type': 'source'}}}

        sigmas = sigmas[skip_steps:]

        x_tgt = x_init.clone()
        N = len(sigmas)-1
        s_in = x_init.new_ones([x_init.shape[0]])
        noise_mask = extra_args.get('denoise_mask', None)
        if noise_mask is None:
            # No mask supplied: every element is editable.
            noise_mask = torch.ones_like(x_init)
        else:
            # Keep the mask for local use and pass denoise_mask=None to both
            # model calls. The mask is used as given (it is not inverted).
            extra_args['denoise_mask'] = None
            source_extra_args['denoise_mask'] = None

        for i in trange(N, disable=disable):
            sigma = sigmas[i]
            noise = torch.randn(x_init.shape, generator=generator).to(x_init.device)

            # Source latent re-noised to the current sigma.
            zt_src = (1-sigma)*x_init + sigma*noise

            if i < N-refine_steps:
                # Regular step: shift the target by the same offset the source
                # received, and evaluate the model on the source as well.
                zt_tgt = x_tgt + zt_src - x_init
                vt_src = model(zt_src, sigma*s_in, **source_extra_args)
            else:
                if i == N-refine_steps:
                    # First refine step: fold the offset into x_tgt itself
                    # (masked), since no source prediction is subtracted from
                    # here on.
                    zt_tgt = x_tgt + (zt_src - x_init)
                    x_tgt = x_tgt + (zt_src - x_init) * noise_mask
                else:
                    # Later refine steps: keep x_tgt where the mask is 1 and
                    # re-noise it where the mask is 0.
                    zt_tgt = x_tgt * (noise_mask) + (1-noise_mask) * ( (1-sigma)*x_tgt + sigma*noise )
                vt_src = 0

            # Tag the shared transformer_options for the target pass.
            transformer_options['latent_type'] = 'target'
            vt_tgt = model(zt_tgt, sigma*s_in, **extra_args)

            # Step x_tgt by the sigma increment times the difference of the
            # two model outputs, restricted to the masked region.
            v_delta = vt_tgt - vt_src
            x_tgt += (sigmas[i+1] - sigmas[i]) * v_delta * noise_mask

            if callback is not None:
                # 'i' is reported relative to the unsliced sigma schedule.
                callback({'x': x_tgt, 'denoised': x_tgt, 'i': i+skip_steps, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return x_tgt

    return flowedit_sample


class HYFlowEditSamplerNode:
    """Node that builds a SAMPLER around the FlowEdit sampling function."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required integer inputs: skip_steps, drift_steps and seed."""
        return {"required": {
            "skip_steps": ("INT", {"default": 4, "min": 0, "max": 0xffffffffffffffff }),
            "drift_steps": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
        }, "optional": {
        }}
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"

    CATEGORY = "hunyuanloom"

    def build(self, skip_steps, drift_steps, seed):
        """Create the sampler.

        Args:
            skip_steps: Forwarded to ``get_flowedit_sample`` as ``skip_steps``.
            drift_steps: Forwarded to ``get_flowedit_sample`` as its second
                argument (named ``refine_steps`` there).
            seed: Seed for the noise generator handed to the sampling function.

        Returns:
            A one-element tuple holding the ``KSAMPLER`` instance.
        """
        # Use a private generator rather than torch.manual_seed(): that call
        # reseeds the process-wide default generator and returns that shared
        # object, so later reseeding or draws elsewhere would alter this
        # sampler's noise.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        sampler = KSAMPLER(get_flowedit_sample(skip_steps, drift_steps, generator))
        return (sampler, )