import ast
from itertools import tee

from tqdm.auto import trange

from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.ldm.models.diffusion.ddim import DDIMSampler
from comfy.samplers import CompVisVDenoiser
from comfy.utils import ProgressBar
from nodes import common_ksampler
import torch

from ..utils.refined_exp_solver import RefinedExpCallbackPayload, _refined_exp_sosu_step
from .restart_schedulers import SCHEDULER_MAPPING

def pairwise(iterable):
    "s -> (s0, s1), (s1, s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)
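# Illustrative check of pairwise (not used by the sampler itself):
#   list(pairwise([0, 1, 2, 3])) -> [(0, 1), (1, 2), (2, 3)]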


def add_restart_segment(restart_segments, n_restart, k, t_min, t_max):
    if restart_segments is None:
        restart_segments = []
    restart_segments.append({'n': n_restart, 'k': k, 't_min': t_min, 't_max': t_max})
    return restart_segments


def prepare_restart_segments(restart_info):
    try:
        restart_arrays = ast.literal_eval(f"[{restart_info}]")
    except (SyntaxError, ValueError) as e:  # literal_eval raises either on malformed input
        print("Ill-formed restart segments")
        raise e
    restart_segments = []
    for arr in restart_arrays:
        if len(arr) != 4:
            raise ValueError("Restart segment must have 4 values")
        n_restart, k, t_min, t_max = arr
        restart_segments = add_restart_segment(restart_segments, n_restart, k, t_min, t_max)
    return restart_segments
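# The restart_info string is one or more "[n, k, t_min, t_max]" quadruples
# separated by commas. A sketch with illustrative values:
#   prepare_restart_segments("[3, 2, 0.06, 0.30], [3, 1, 0.30, 0.59]")
#   -> [{'n': 3, 'k': 2, 't_min': 0.06, 't_max': 0.30},
#       {'n': 3, 'k': 1, 't_min': 0.30, 't_max': 0.59}]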


def round_restart_segments(sigmas, restart_segments):
    s_max = sigmas[0]
    t_min_mapping = {}
    for segment in reversed(restart_segments):  # Reversed to prioritize segments at the front
        if segment['t_max'] > s_max:
            continue  # segment starts above the schedule's max sigma; toss it
        t_min_neighbor = min(sigmas, key=lambda s: abs(s - segment['t_min'])).item()
        t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
    return t_min_mapping
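# Sketch: each segment's t_min is snapped to the nearest sigma in the schedule
# so the in-loop lookup can key on exact values. Assuming 0.29 is the schedule
# sigma closest to t_min=0.30 (hypothetical numbers):
#   round_restart_segments(sigmas, [{'n': 3, 'k': 2, 't_min': 0.30, 't_max': 0.59}])
#   -> {0.29: {'n': 3, 'k': 2, 't_max': 0.59}}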


def segments_to_timesteps(restart_segments, model):
    timesteps = []
    for segment in restart_segments:
        t_min, t_max = model.sigma_to_t(torch.tensor(
            [segment['t_min'], segment['t_max']], device=model.log_sigmas.device))
        ts_segment = {'n': segment['n'], 'k': segment['k'], 't_min': t_min, 't_max': t_max}
        timesteps.append(ts_segment)
    return timesteps


def round_restart_segments_timesteps(timesteps, restart_segments):
    t_min_mapping = {}
    for segment in reversed(restart_segments):  # Reversed to prioritize segments to the front
        t_min_neighbor = min(timesteps, key=lambda ts: abs(ts - segment['t_min'])).item()
        t_min_mapping[t_min_neighbor] = {'n': segment['n'], 'k': segment['k'], 't_max': segment['t_max']}
    return t_min_mapping


def calc_sigmas(scheduler, n, sigma_min, sigma_max, model, device):
    return SCHEDULER_MAPPING[scheduler](model, n, sigma_min, sigma_max, device)
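# Dispatches on the restart scheduler name; e.g., assuming "karras" is one of
# the keys defined in restart_schedulers.SCHEDULER_MAPPING:
#   seg_sigmas = calc_sigmas("karras", n_restart, s_min, s_max, model.inner_model, device)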


def calc_restart_steps(restart_segments):
    restart_steps = 0
    for segment in restart_segments.values():
        restart_steps += (segment['n'] - 1) * segment['k']
    return restart_steps
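# Each segment contributes (n - 1) * k extra denoising steps: e.g. a segment
# with n=4, k=2 adds (4 - 1) * 2 = 6 steps to the progress-bar total.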


_total_steps = 0
_restart_segments = None
_restart_scheduler = None


def restart_sampling(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise, restart_info, restart_scheduler, begin_at_step, end_at_step, disable_noise, force_full_denoise):
    global _total_steps, _restart_segments, _restart_scheduler
    _restart_scheduler = restart_scheduler
    _restart_segments = prepare_restart_segments(restart_info)

    if sampler_name == "res":
        sampler_wrapper = RESWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
    elif sampler_name == "ddim":
        sampler_wrapper = DDIMWrapper(begin_at_step, end_at_step, _restart_segments, _restart_scheduler)
    else:
        sampler_wrapper = KSamplerRestartWrapper(sampler_name, begin_at_step, end_at_step, _restart_segments, _restart_scheduler)

    # Add the additional steps to the progress bar
    pbar_update_absolute = ProgressBar.update_absolute

    def pbar_update_absolute_wrapper(self, value, total=None, preview=None):
        pbar_update_absolute(self, value, _total_steps, preview)

    ProgressBar.update_absolute = pbar_update_absolute_wrapper

    try:
        samples = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler,
                                  positive, negative, latent_image, denoise=denoise, force_full_denoise=force_full_denoise, disable_noise=disable_noise)
    finally:
        sampler_wrapper.cleanup()
        ProgressBar.update_absolute = pbar_update_absolute
    return samples
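# Hypothetical invocation from a node (all argument values illustrative):
#   samples = restart_sampling(model, seed=0, steps=20, cfg=7.0,
#                              sampler_name="euler", scheduler="normal",
#                              positive=pos, negative=neg, latent_image=latent,
#                              denoise=1.0, restart_info="[3, 2, 0.06, 0.30]",
#                              restart_scheduler="karras", begin_at_step=0,
#                              end_at_step=20, disable_noise=False,
#                              force_full_denoise=False)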


class RestartWrapper:
    """Base class for the sampler monkey-patch wrappers used by restart_sampling."""

    def cleanup(self):
        # Subclasses restore whatever they patched; the base is a no-op.
        pass

class RESWrapper(RestartWrapper):
    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        self.__class__.segments = segments
        self.__class__.scheduler = scheduler
        self.__class__.refiner_stage = False
        self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
        self.sample_func_name = "sample_res"
        setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)

    def cleanup(self):
        # Remove the injected "sample_res" entry so later runs start clean.
        if hasattr(k_diffusion_sampling, self.sample_func_name):
            delattr(k_diffusion_sampling, self.sample_func_name)

    @staticmethod
    @torch.no_grad()
    def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
        global _total_steps, _restart_segments, _restart_scheduler
        extra_args = {} if extra_args is None else extra_args
        _restart_scheduler = __class__.scheduler
        _restart_segments = __class__.segments
        segments = round_restart_segments(sigmas, _restart_segments)
        _total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
        step = 0
        real_steps = 0

        eta: torch.FloatTensor = torch.zeros((1,), device=x.device)
        simple_phi_calc = True
        c2 = .5
        with trange(_total_steps, disable=disable) as pbar:
            # The loop pairs consecutive sigmas from sigmas[:-1], so it stops one
            # transition short of sigmas[-1]; when refiner_stage is set, a single
            # extra denoise at the final sigma runs after the loop.
            for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
                if real_steps > __class__.cfg_clamp_after_step:
                    extra_args['cond_scale'] = 1.0

                # Inflate the noise level by eta, then take one refined
                # exponential-solver (SOSU) step down to sigma_next.
                eps: torch.FloatTensor = torch.randn_like(x)
                sigma_hat = sigma * (1 + eta)
                x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps
                x_next, denoised, denoised2 = _refined_exp_sosu_step(
                    model,
                    x_hat,
                    sigma_hat,
                    sigma_next,
                    c2=c2,
                    extra_args=extra_args,
                    pbar=pbar,
                    simple_phi_calc=simple_phi_calc,
                )
                if callback is not None:
                    payload = RefinedExpCallbackPayload(
                        x=x,
                        i=step,
                        sigma=sigma,
                        sigma_hat=sigma_hat,
                        denoised=denoised,
                        denoised2=denoised2,
                    )
                    callback(payload)
                x = x_next
                pbar.update(1)
                step += 1
                real_steps += 1
                if sigmas[i].item() in segments:
                    seg = segments[sigmas[i].item()]
                    s_min, s_max, k, n_restart = sigmas[i + 1], seg['t_max'], seg['k'], seg['n']
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
                                             s_max, model.inner_model, device=x.device)
                    for _ in range(k):
                        # Re-noise from s_min back up to s_max, then walk the
                        # restart schedule down again.
                        x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            eps = torch.randn_like(x)
                            sigma_hat = seg_sigmas[j] * (1 + eta)
                            x_hat = x + (sigma_hat ** 2 - seg_sigmas[j] ** 2) ** .5 * eps
                            x_next, denoised, denoised2 = _refined_exp_sosu_step(
                                model,
                                x_hat,
                                sigma_hat,
                                seg_sigmas[j + 1],
                                c2=c2,
                                extra_args=extra_args,
                                pbar=pbar,
                                simple_phi_calc=simple_phi_calc,
                            )
                            x = x_next  # carry the restart step's output forward
                            pbar.update(1)
                            step += 1
            if __class__.refiner_stage:
                # One final stochastic denoise at the last sigma reached.
                eps = torch.randn_like(x)
                sigma_hat = sigma * (1 + eta)
                x_hat = x + (sigma_hat ** 2 - sigma ** 2) ** .5 * eps
                x_next: torch.FloatTensor = model(x_hat, sigma.to(x_hat.device), **extra_args)
                pbar.update(1)
                x = x_next
        return x

class KSamplerRestartWrapper(RestartWrapper):

    ksampler = None

    def __init__(self, sampler_name, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        self.sample_func_name = f"sample_{sampler_name}"
        self.__class__.segments = segments
        self.__class__.ksampler = getattr(k_diffusion_sampling, self.sample_func_name)
        self.__class__.begin_at_step = begin_at_step
        self.__class__.end_at_step = end_at_step
        self.__class__.scheduler = scheduler
        self.__class__.refiner_stage = False
        self.__class__.original_sigmas = None
        self.__class__.total_steps = 0
        self.__class__.continuation_step = 0
        self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
        setattr(k_diffusion_sampling, self.sample_func_name, self.ksampler_restart_wrapper)

    def cleanup(self):
        setattr(k_diffusion_sampling, self.sample_func_name, KSamplerRestartWrapper.ksampler)

    @staticmethod
    @torch.no_grad()
    def ksampler_restart_wrapper(model, x, sigmas, extra_args=None, callback=None, disable=None):
        global _total_steps, _restart_segments, _restart_scheduler
        ksampler = __class__.ksampler
        extra_args = {} if extra_args is None else extra_args
        _restart_scheduler = __class__.scheduler
        _restart_segments = __class__.segments
        begin_at_step = __class__.begin_at_step
        end_at_step = __class__.end_at_step
        segments = round_restart_segments(sigmas, _restart_segments)
        _total_steps = len(sigmas) - 1 + calc_restart_steps(segments)
        step = 0
        if not __class__.refiner_stage:
            __class__.original_sigmas = sigmas
            __class__.total_steps = _total_steps
        else:
            # TODO: recompute begin_at_step / end_at_step from the number of
            # steps already completed; currently a no-op.
            step = 0

        def callback_wrapper(x):
            x["i"] = step
            if callback is not None:
                callback(x)
        
        real_steps = 0
        with trange(_total_steps, disable=disable) as pbar:
            for i in range(len(sigmas) - 1):
                if real_steps > __class__.cfg_clamp_after_step:
                    extra_args['cond_scale'] = 1.0
                # if i + 1 > end_at_step:
                #     __class__.continuation_step = i
                #     break
                x = ksampler(model, x, torch.tensor([sigmas[i], sigmas[i + 1]],
                                                    device=x.device), extra_args, callback_wrapper, True)
                pbar.update(1)
                step += 1
                real_steps += 1
                if sigmas[i].item() in segments:
                    seg = segments[sigmas[i].item()]
                    s_min, s_max, k, n_restart = sigmas[i+1], seg['t_max'], seg['k'], seg['n']
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
                                             s_max, model.inner_model, device=x.device)
                    for _ in range(k):
                        x += torch.randn_like(x) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            x = ksampler(model, x, torch.tensor(
                                [seg_sigmas[j], seg_sigmas[j + 1]], device=x.device), extra_args, callback_wrapper, True)
                            pbar.update(1)
                            step += 1
        return x


class DDIMWrapper(RestartWrapper):

    def __init__(self, begin_at_step, end_at_step, segments, scheduler, cfg_clamp_after_step=0):
        self.__class__.sample_custom = DDIMSampler.sample_custom
        self.__class__.begin_at_step = begin_at_step
        self.__class__.end_at_step = end_at_step
        self.__class__.segments = segments
        self.__class__.scheduler = scheduler
        self.__class__.refiner_stage = False
        self.__class__.cfg_clamp_after_step = cfg_clamp_after_step
        DDIMSampler.sample_custom = self.ddim_wrapper

    def cleanup(self):
        DDIMSampler.sample_custom = self.__class__.sample_custom

    @staticmethod
    @torch.no_grad()
    def ddim_wrapper(self, ddim_timesteps, conditioning, callback=None, img_callback=None, quantize_x0=False,
                     eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None,
                     corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.,
                     unconditional_conditioning=None, dynamic_threshold=None, ucg_schedule=None, denoise_function=None,
                     extra_args=None, to_zero=True, end_step=None, disable_pbar=False, **kwargs):
        global _total_steps, _restart_segments, _restart_scheduler
        ddim_sampler = __class__.sample_custom
        model_denoise = CompVisVDenoiser(self.model)
        extra_args = {} if extra_args is None else extra_args
        begin_at_step = __class__.begin_at_step
        end_at_step = __class__.end_at_step
        _restart_segments = __class__.segments
        _restart_scheduler = __class__.scheduler
        segments = segments_to_timesteps(_restart_segments, model_denoise)
        segments = round_restart_segments_timesteps(ddim_timesteps, segments)
        _total_steps = len(ddim_timesteps) - 1 + calc_restart_steps(segments)
        step = 0

        def callback_wrapper(pred_x0, i):
            if img_callback is not None:
                img_callback(pred_x0, step)

        def ddim_simplified(x, timesteps, x_T=None, disable_pbar=False):
            if x_T is None:
                self.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
                x_T = self.stochastic_encode(x, torch.tensor(
                    [len(timesteps) - 1] * x.shape[0]).to(self.device), noise=torch.zeros_like(x), max_denoise=False)
            x, intermediates = ddim_sampler(
                self, timesteps, conditioning, callback=callback, img_callback=callback_wrapper, quantize_x0=quantize_x0,
                eta=eta, mask=mask, x0=x, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs, verbose=verbose, x_T=x_T, log_every_t=log_every_t,
                unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning,
                dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=denoise_function, extra_args=extra_args,
                to_zero=timesteps[0].item() == 0, end_step=len(timesteps) - 1, disable_pbar=disable_pbar
            )
            return x, intermediates
        intermediates = None
        real_steps = 0
        with trange(_total_steps, disable=disable_pbar) as pbar:
            # Walk the timestep indices from high noise to low, restricted by
            # begin_at_step / end_at_step.
            rg = reversed(range(
                len(ddim_timesteps) - 1 - min(len(ddim_timesteps) - 1, end_at_step + 1),  # start
                len(ddim_timesteps) - max(begin_at_step, 0)  # end
            ))
            for i in rg:
                if real_steps > __class__.cfg_clamp_after_step:
                    extra_args['cond_scale'] = 1.0
                x0, intermediates = ddim_simplified(x0, ddim_timesteps[i:i + 2], x_T=x_T, disable_pbar=True)
                x_T = None
                pbar.update(1)
                step += 1
                real_steps += 1
                if ddim_timesteps[i].item() in segments:
                    seg = segments[ddim_timesteps[i].item()]
                    t_min, t_max, k, n_restart = ddim_timesteps[i], seg['t_max'], seg['k'], seg['n']
                    s_min, s_max = model_denoise.t_to_sigma(t_min), model_denoise.t_to_sigma(t_max)
                    seg_sigmas = calc_sigmas(_restart_scheduler, n_restart, s_min,
                                             s_max, model_denoise, device=x0.device)
                    for _ in range(k):
                        x0 += torch.randn_like(x0) * (s_max ** 2 - s_min ** 2) ** 0.5
                        for j in range(n_restart - 1):
                            seg_ts = model_denoise.sigma_to_t(seg_sigmas[j]).to(torch.int32)
                            seg_ts_next = model_denoise.sigma_to_t(seg_sigmas[j + 1]).to(torch.int32)
                            x0, intermediates = ddim_simplified(x0, [seg_ts_next, seg_ts], disable_pbar=True)
                            pbar.update(1)
                            step += 1
        return x0, intermediates