# This extension works with [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
# version: v1.1.229

LOG_PREFIX = '[ControlNet-Travel]'

# ↓↓↓ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↓↓↓

CTRLNET_REPO_NAME = 'sdcontrol'
if 'external repo sanity check':
    from pathlib import Path
    from modules.scripts import basedir
    from traceback import print_exc

    ME_PATH = Path(basedir())
    CTRLNET_PATH = ME_PATH.parent / CTRLNET_REPO_NAME

    controlnet_found = False
    try:
        import sys ; sys.path.append(str(CTRLNET_PATH))
        #from scripts.controlnet import Script as ControlNetScript  # NOTE: this will mess up the import order
        from scripts.external_code import ControlNetUnit
        from scripts.hook import UNetModel, UnetHook, ControlParams
        from scripts.hook import *

        controlnet_found = True
        print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} found, ControlNet-Travel loaded :)')
    except ImportError:
        print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} not found, ControlNet-Travel ignored :(')
        exit(0)
    except:
        print_exc()
        exit(0)

# ↑↑↑ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↑↑↑


import sys
from PIL import Image

from ldm.models.diffusion.ddpm import LatentDiffusion
from modules import shared, devices, lowvram
from modules.processing import StableDiffusionProcessing as Processing

from scripts.prompt_travel import *
from manager import run_cmd

class InterpMethod(Enum):
    LINEAR = 'linear (weight sum)'
    RIFE   = 'rife (optical flow)'

if 'consts':
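    # helper: look up a previously saved UI value for this script from webui `opts`, falling back to the given default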
    __ = lambda key, value=None: opts.data.get(f'customscript/controlnet_travel.py/txt2img/{key}/value', value)


    LABEL_CTRLNET_REF_DIR   = 'Reference image folder (one ref image per stage :)'
    LABEL_INTERP_METH       = 'Interpolate method'
    LABEL_SKIP_FUSE         = 'Ext. skip latent fusion'
    LABEL_DEBUG_RIFE        = 'Save RIFE intermediates'

    DEFAULT_STEPS           = 10
    DEFAULT_CTRLNET_REF_DIR = str(ME_PATH / 'img' / 'ref_ctrlnet')
    DEFAULT_INTERP_METH     = __(LABEL_INTERP_METH, InterpMethod.LINEAR.value)
    DEFAULT_SKIP_FUSE       = __(LABEL_SKIP_FUSE, False)
    DEFAULT_DEBUG_RIFE      = __(LABEL_DEBUG_RIFE, False)

    CHOICES_INTERP_METH     = [x.value for x in InterpMethod]

if 'vars':
    skip_fuse_plan:       List[bool]         = []   # n_blocks (13)

    interp_alpha:         float              = 0.0
    interp_ip:            int                = 0    # 0 ~ n_sampling_step-1
    from_hint_cond:       List[Tensor]       = []   # n_controlnet_set
    to_hint_cond:         List[Tensor]       = []
    mid_hint_cond:        List[Tensor]       = []
    from_control_tensors: List[List[Tensor]] = []   # n_sampling_step x n_blocks
    to_control_tensors:   List[List[Tensor]] = []

    caches: List[list] = [from_hint_cond, to_hint_cond, mid_hint_cond, from_control_tensors, to_control_tensors]
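    # NOTE: travel state shared with the hooked UNet forward below:
    #   interp_alpha == 0.0 -> key frame: collect hint conds / control tensors into the `to_*` caches
    #   interp_alpha  > 0.0 -> interpolated frame: overwrite them with blends of the cached `from_*`/`to_*` values
    #   interp_ip indexes the current sampling step into the per-step control tensor caches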


# ↓↓↓ the following is modified from 'sd-webui-controlnet/scripts/hook.py' ↓↓↓
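# The main changes relative to the upstream hook are the blocks marked "hint shallow fusion" and
# "hint latent fusion" below: they cache ControlNet hint conds / per-block output tensors on key frames
# and re-inject interpolated versions on in-between frames; the rest is kept as close to the original as possible.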

def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing):
    self.model = model
    self.sd_ldm = sd_ldm
    self.control_params = control_params

    outer = self

    def process_sample(*args, **kwargs):
        # ControlNet must know whether a prompt is a conditional prompt (positive) or an unconditional conditioning prompt (negative).
        # Use hook.py's `mark_prompt_context` to mark the prompts that will be seen by ControlNet.
        # Say XXX is a MulticondLearnedConditioning, a ComposableScheduledPromptConditioning, a ScheduledPromptConditioning, or a list of these:
        #   if XXX is a positive prompt, call mark_prompt_context(XXX, positive=True)
        #   if XXX is a negative prompt, call mark_prompt_context(XXX, positive=False)
        # Once the prompts are marked, ControlNet knows which prompt is cond/uncond, works as expected, and the mismatch errors disappear.
        mark_prompt_context(kwargs.get('conditioning', []), positive=True)
        mark_prompt_context(kwargs.get('unconditional_conditioning', []), positive=False)
        mark_prompt_context(getattr(process, 'hr_c', []), positive=True)
        mark_prompt_context(getattr(process, 'hr_uc', []), positive=False)
        return process.sample_before_CN_hack(*args, **kwargs)

    # NOTE: ↓↓↓ only hack this method ↓↓↓
    def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
        total_controlnet_embedding = [0.0] * 13
        total_t2i_adapter_embedding = [0.0] * 4
        require_inpaint_hijack = False
        is_in_high_res_fix = False
        batch_size = int(x.shape[0])

        # NOTE: declare globals
        global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, mid_hint_cond, interp_alpha, interp_ip
        x: Tensor           # [1, 4, 64, 64]
        timesteps: Tensor   # [1]
        context: Tensor     # [1, 78, 768]
        kwargs: dict        # {}

        # Handle cond-uncond marker
        cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
        # logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))

        # High-res fix
        for param in outer.control_params:
            # select which hint_cond to use
            if param.used_hint_cond is None:
                param.used_hint_cond = param.hint_cond      # NOTE: input hint cond tensor, [1, 3, 512, 512]
                param.used_hint_cond_latent = None
                param.used_hint_inpaint_hijack = None

            # has high-res fix
            if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
                _, _, h_lr, w_lr = param.hint_cond.shape
                _, _, h_hr, w_hr = param.hr_hint_cond.shape
                _, _, h, w = x.shape
                h, w = h * 8, w * 8
                if abs(h - h_lr) < abs(h - h_hr):
                    is_in_high_res_fix = False
                    if param.used_hint_cond is not param.hint_cond:
                        param.used_hint_cond = param.hint_cond
                        param.used_hint_cond_latent = None
                        param.used_hint_inpaint_hijack = None
                else:
                    is_in_high_res_fix = True
                    if param.used_hint_cond is not param.hr_hint_cond:
                        param.used_hint_cond = param.hr_hint_cond
                        param.used_hint_cond_latent = None
                        param.used_hint_inpaint_hijack = None

        # NOTE: hint shallow fusion, overwrite param.used_hint_cond
        for i, param in enumerate(outer.control_params):
            if interp_alpha == 0.0:     # collect hint_cond on key frames
                if len(to_hint_cond) < len(outer.control_params):
                    to_hint_cond.append(param.used_hint_cond.clone().detach().cpu())
            else:                       # interp with cached hint_cond
                param.used_hint_cond = mid_hint_cond[i].to(x.device)

        # Convert control image to latent
        for param in outer.control_params:
            if param.used_hint_cond_latent is not None:
                continue
            if param.control_model_type not in [ControlModelType.AttentionInjection] \
                    and 'colorfix' not in param.preprocessor['name'] \
                    and 'inpaint_only' not in param.preprocessor['name']:
                continue
            param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size)
                
        # handle prompt token control
        for param in outer.control_params:
            if param.guidance_stopped:
                continue

            if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
                continue

            param.control_model.to(devices.get_device_for("controlnet"))
            control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
            control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
            control *= param.weight
            control *= cond_mark[:, :, :, 0]
            context = torch.cat([context, control.clone()], dim=1)

        # handle ControlNet / T2I_Adapter
        for param in outer.control_params:
            if param.guidance_stopped:
                continue

            if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
                continue

            param.control_model.to(devices.get_device_for("controlnet"))
            # inpaint model workaround
            x_in = x
            control_model = param.control_model.control_model

            if param.control_model_type == ControlModelType.ControlNet:
                if x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
                    # inpaint_model: 4 data + 4 downscaled image + 1 mask
                    x_in = x[:, :4, ...]
                    require_inpaint_hijack = True

            assert param.used_hint_cond is not None, "ControlNet is enabled but no input image is given"

            hint = param.used_hint_cond

            # ControlNet inpaint protocol
            if hint.shape[1] == 4:
                c = hint[:, 0:3, :, :]
                m = hint[:, 3:4, :, :]
                m = (m > 0.5).float()
                hint = c * (1 - m) - m

            # NOTE: len(control) == 13, control[i]:Tensor
            control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
            control_scales = ([param.weight] * 13)

            if outer.lowvram:
                param.control_model.to("cpu")

            if param.cfg_injection or param.global_average_pooling:
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    control = [torch.cat([c.clone() for _ in range(batch_size)], dim=0) for c in control]
                control = [c * cond_mark for c in control]

            high_res_fix_forced_soft_injection = False

            if is_in_high_res_fix:
                if 'canny' in param.preprocessor['name']:
                    high_res_fix_forced_soft_injection = True
                if 'mlsd' in param.preprocessor['name']:
                    high_res_fix_forced_soft_injection = True

            # if high_res_fix_forced_soft_injection:
            #     logger.info('[ControlNet] Forced soft_injection in high_res_fix is enabled.')

            if param.soft_injection or high_res_fix_forced_soft_injection:
                # important! using the soft weights with high-res fix can significantly reduce artifacts.
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    control_scales = [param.weight * x for x in (0.25, 0.62, 0.825, 1.0)]
                elif param.control_model_type == ControlModelType.ControlNet:
                    control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]

            if param.advanced_weighting is not None:
                control_scales = param.advanced_weighting

            control = [c * scale for c, scale in zip(control, control_scales)]
            if param.global_average_pooling:
                control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]

            for idx, item in enumerate(control):
                target = None
                if param.control_model_type == ControlModelType.ControlNet:
                    target = total_controlnet_embedding
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    target = total_t2i_adapter_embedding
                if target is not None:
                    target[idx] = item + target[idx]

        # Replace x_t to support inpaint models
        for param in outer.control_params:
            if param.used_hint_cond.shape[1] != 4:
                continue
            if x.shape[1] != 9:
                continue
            if param.used_hint_inpaint_hijack is None:
                mask_pixel = param.used_hint_cond[:, 3:4, :, :]
                image_pixel = param.used_hint_cond[:, 0:3, :, :]
                mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype)
                masked_latent = outer.call_vae_using_process(process, image_pixel, batch_size, mask=mask_pixel)
                mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8))
                if mask_latent.shape[0] != batch_size:
                    mask_latent = torch.cat([mask_latent.clone() for _ in range(batch_size)], dim=0)
                param.used_hint_inpaint_hijack = torch.cat([mask_latent, masked_latent], dim=1)
                param.used_hint_inpaint_hijack = param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)  # .to() is not in-place; keep the converted tensor
            x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)

        # A1111 fix for medvram.
        if shared.cmd_opts.medvram:
            try:
                # Trigger the register_forward_pre_hook
                outer.sd_ldm.model()
            except:
                pass

        # Clear attention and AdaIn cache
        for module in outer.attn_module_list:
            module.bank = []
            module.style_cfgs = []
        for module in outer.gn_module_list:
            module.mean_bank = []
            module.var_bank = []
            module.style_cfgs = []

        # Handle attention and AdaIn control
        for param in outer.control_params:
            if param.guidance_stopped:
                continue

            if param.used_hint_cond_latent is None:
                continue

            if param.control_model_type not in [ControlModelType.AttentionInjection]:
                continue

            ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())

            # Inpaint Hijack
            if x.shape[1] == 9:
                ref_xt = torch.cat([
                    ref_xt,
                    torch.zeros_like(ref_xt)[:, 0:1, :, :],
                    param.used_hint_cond_latent
                ], dim=1)

            outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
            outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))

            if param.cfg_injection:
                outer.current_style_fidelity = 1.0
            elif param.soft_injection or is_in_high_res_fix:
                outer.current_style_fidelity = 0.0

            control_name = param.preprocessor['name']

            if control_name in ['reference_only', 'reference_adain+attn']:
                outer.attention_auto_machine = AutoMachine.Write
                outer.attention_auto_machine_weight = param.weight

            if control_name in ['reference_adain', 'reference_adain+attn']:
                outer.gn_auto_machine = AutoMachine.Write
                outer.gn_auto_machine_weight = param.weight

            outer.original_forward(
                x=ref_xt.to(devices.dtype_unet),
                timesteps=timesteps.to(devices.dtype_unet),
                context=context.to(devices.dtype_unet)
            )

            outer.attention_auto_machine = AutoMachine.Read
            outer.gn_auto_machine = AutoMachine.Read

        # NOTE: hint latent fusion, overwrite control tensors
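        # On key frames the 13 per-block ControlNet outputs are cached (unless masked out by skip_fuse_plan);
        # on interpolated frames each cached pair is blended with weighted_sum (the linear lerp from prompt_travel,
        # assumed to compute (1-alpha)*A + alpha*B) and written back into total_control in place.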
        total_control = total_controlnet_embedding
        if interp_alpha == 0.0:     # collect control tensors on key frames
            tensors: List[Tensor] = []
            for i, t in enumerate(total_control):
                if len(skip_fuse_plan) and skip_fuse_plan[i]:
                    tensors.append(None)
                else:
                    tensors.append(t.clone().detach().cpu())
            to_control_tensors.append(tensors)
        else:                       # interp with cached control tensors
            device = total_control[0].device
            for i, (ctrlA, ctrlB) in enumerate(zip(from_control_tensors[interp_ip], to_control_tensors[interp_ip])):
                if ctrlA is not None and ctrlB is not None:
                    ctrlC = weighted_sum(ctrlA.to(device), ctrlB.to(device), interp_alpha)
                    #print('  ctrl diff:', (ctrlC - total_control[i]).abs().mean().item())
                    total_control[i].data = ctrlC
            interp_ip += 1
        
        # NOTE: warn on T2I adapter
        if total_t2i_adapter_embedding[0] != 0:
            print(f'{LOG_PREFIX} warn: t2i_adapter is currently not supported. if you want it, file a feature request at Kahsolt/stable-diffusion-webui-prompt-travel')

        # U-Net Encoder
        hs = []
        with th.no_grad():
            t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for i, module in enumerate(self.input_blocks):
                h = module(h, emb, context)

                if (i + 1) % 3 == 0:
                    h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)

                hs.append(h)
            h = self.middle_block(h, emb, context)

        # U-Net Middle Block
        h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)

        # U-Net Decoder
        for i, module in enumerate(self.output_blocks):
            h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
            h = module(h, emb, context)

        # U-Net Output
        h = h.type(x.dtype)
        h = self.out(h)

        # Post-processing for color fix
        for param in outer.control_params:
            if param.used_hint_cond_latent is None:
                continue
            if 'colorfix' not in param.preprocessor['name']:
                continue

            k = int(param.preprocessor['threshold_a'])
            if is_in_high_res_fix:
                k *= 2

            # Inpaint hijack
            xt = x[:, :4, :, :]

            x0_origin = param.used_hint_cond_latent
            t = torch.round(timesteps.float()).long()
            x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
            x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k)

            if '+sharp' in param.preprocessor['name']:
                detail_weight = float(param.preprocessor['threshold_b']) * 0.01
                neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0
                x0 = cond_mark * x0 + (1 - cond_mark) * neg

            eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)

            w = max(0.0, min(1.0, float(param.weight)))
            h = eps_prd * w + h * (1 - w)

        # Post-processing for restore
        for param in outer.control_params:
            if param.used_hint_cond_latent is None:
                continue
            if 'inpaint_only' not in param.preprocessor['name']:
                continue
            if param.used_hint_cond.shape[1] != 4:
                continue

            # Inpaint hijack
            xt = x[:, :4, :, :]

            mask = param.used_hint_cond[:, 3:4, :, :]
            mask = torch.nn.functional.max_pool2d(mask, (10, 10), stride=(8, 8), padding=1)

            x0_origin = param.used_hint_cond_latent
            t = torch.round(timesteps.float()).long()
            x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
            x0 = x0_prd * mask + x0_origin * (1 - mask)
            eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)

            w = max(0.0, min(1.0, float(param.weight)))
            h = eps_prd * w + h * (1 - w)

        return h

    def forward_webui(*args, **kwargs):
        # webui will handle the other components
        try:
            if shared.cmd_opts.lowvram:
                lowvram.send_everything_to_cpu()

            return forward(*args, **kwargs)
        finally:
            if self.lowvram:
                for param in self.control_params:
                    if isinstance(param.control_model, torch.nn.Module):
                        param.control_model.to("cpu")

    def hacked_basic_transformer_inner_forward(self, x, context=None):
        x_norm1 = self.norm1(x)
        self_attn1 = None
        if self.disable_self_attn:
            # Do not use self-attention
            self_attn1 = self.attn1(x_norm1, context=context)
        else:
            # Use self-attention
            self_attention_context = x_norm1
            if outer.attention_auto_machine == AutoMachine.Write:
                if outer.attention_auto_machine_weight > self.attn_weight:
                    self.bank.append(self_attention_context.detach().clone())
                    self.style_cfgs.append(outer.current_style_fidelity)
            if outer.attention_auto_machine == AutoMachine.Read:
                if len(self.bank) > 0:
                    style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                    self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
                    self_attn1_c = self_attn1_uc.clone()
                    if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                        self_attn1_c[outer.current_uc_indices] = self.attn1(
                            x_norm1[outer.current_uc_indices],
                            context=self_attention_context[outer.current_uc_indices])
                    self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                self.bank = []
                self.style_cfgs = []
            if self_attn1 is None:
                self_attn1 = self.attn1(x_norm1, context=self_attention_context)

        x = self_attn1.to(x.dtype) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

    def hacked_group_norm_forward(self, *args, **kwargs):
        eps = 1e-6
        x = self.original_forward(*args, **kwargs)
        y = None
        if outer.gn_auto_machine == AutoMachine.Write:
            if outer.gn_auto_machine_weight > self.gn_weight:
                var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                self.mean_bank.append(mean)
                self.var_bank.append(var)
                self.style_cfgs.append(outer.current_style_fidelity)
        if outer.gn_auto_machine == AutoMachine.Read:
            if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                var_acc = sum(self.var_bank) / float(len(self.var_bank))
                std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                y_uc = (((x - mean) / std) * std_acc) + mean_acc
                y_c = y_uc.clone()
                if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                    y_c[outer.current_uc_indices] = x.to(y_c.dtype)[outer.current_uc_indices]
                y = style_cfg * y_c + (1.0 - style_cfg) * y_uc
            self.mean_bank = []
            self.var_bank = []
            self.style_cfgs = []
        if y is None:
            y = x
        return y.to(x.dtype)

    if getattr(process, 'sample_before_CN_hack', None) is None:
        process.sample_before_CN_hack = process.sample
    process.sample = process_sample

    model._original_forward = model.forward
    outer.original_forward = model.forward
    model.forward = forward_webui.__get__(model, UNetModel)

    all_modules = torch_dfs(model)

    attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
    attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])

    for i, module in enumerate(attn_modules):
        if getattr(module, '_original_inner_forward', None) is None:
            module._original_inner_forward = module._forward
        module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
        module.bank = []
        module.style_cfgs = []
        module.attn_weight = float(i) / float(len(attn_modules))

    gn_modules = [model.middle_block]
    model.middle_block.gn_weight = 0

    input_block_indices = [4, 5, 7, 8, 10, 11]
    for w, i in enumerate(input_block_indices):
        module = model.input_blocks[i]
        module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
        gn_modules.append(module)

    output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
    for w, i in enumerate(output_block_indices):
        module = model.output_blocks[i]
        module.gn_weight = float(w) / float(len(output_block_indices))
        gn_modules.append(module)

    for i, module in enumerate(gn_modules):
        if getattr(module, 'original_forward', None) is None:
            module.original_forward = module.forward
        module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
        module.mean_bank = []
        module.var_bank = []
        module.style_cfgs = []
        module.gn_weight *= 2

    outer.attn_module_list = attn_modules
    outer.gn_module_list = gn_modules

    scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)

# ↑↑↑ the above is modified from 'sd-webui-controlnet/scripts/hook.py' ↑↑↑

def reset_cuda():
    devices.torch_gc()
    import gc; gc.collect()

    try:
        import os
        import psutil
        mem = psutil.Process(os.getpid()).memory_info()
        print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB')
        from modules.shared import mem_mon as vram_mon
        free, total = vram_mon.cuda_mem_get_info()
        print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB')
    except:
        pass


class Script(scripts.Script):

    def title(self):
        return 'ControlNet Travel'

    def describe(self):
        return 'Travel from one controlnet hint condition to another in the tensor space.'

    def show(self, is_img2img):
        return controlnet_found

    def ui(self, is_img2img):
        with gr.Row(variant='compact'):
            interp_meth = gr.Dropdown(label=LABEL_INTERP_METH, value=lambda: DEFAULT_INTERP_METH, choices=CHOICES_INTERP_METH)
            steps       = gr.Text    (label=LABEL_STEPS,       value=lambda: DEFAULT_STEPS,       max_lines=1)
            
            reset = gr.Button(value='Reset Cuda', variant='tool')
            reset.click(fn=reset_cuda, show_progress=False)

        with gr.Row(variant='compact'):
            ctrlnet_ref_dir = gr.Text(label=LABEL_CTRLNET_REF_DIR, value=lambda: DEFAULT_CTRLNET_REF_DIR, max_lines=1)

        with gr.Group(visible=DEFAULT_SKIP_FUSE) as tab_ext_skip_fuse:
            with gr.Row(variant='compact'):
                skip_in_0  = gr.Checkbox(label='in_0')
                skip_in_3  = gr.Checkbox(label='in_3')
                skip_out_0 = gr.Checkbox(label='out_0')
                skip_out_3 = gr.Checkbox(label='out_3')
            with gr.Row(variant='compact'):
                skip_in_1  = gr.Checkbox(label='in_1')
                skip_in_4  = gr.Checkbox(label='in_4')
                skip_out_1 = gr.Checkbox(label='out_1')
                skip_out_4 = gr.Checkbox(label='out_4')
            with gr.Row(variant='compact'):
                skip_in_2  = gr.Checkbox(label='in_2')
                skip_in_5  = gr.Checkbox(label='in_5')
                skip_out_2 = gr.Checkbox(label='out_2')
                skip_out_5 = gr.Checkbox(label='out_5')
            with gr.Row(variant='compact'):
                skip_mid   = gr.Checkbox(label='mid')

        with gr.Row(variant='compact', visible=DEFAULT_UPSCALE) as tab_ext_upscale:
            upscale_meth   = gr.Dropdown(label=LABEL_UPSCALE_METH,   value=lambda: DEFAULT_UPSCALE_METH,   choices=CHOICES_UPSCALER)
            upscale_ratio  = gr.Slider  (label=LABEL_UPSCALE_RATIO,  value=lambda: DEFAULT_UPSCALE_RATIO,  minimum=1.0, maximum=16.0, step=0.1)
            upscale_width  = gr.Slider  (label=LABEL_UPSCALE_WIDTH,  value=lambda: DEFAULT_UPSCALE_WIDTH,  minimum=0,   maximum=2048, step=8)
            upscale_height = gr.Slider  (label=LABEL_UPSCALE_HEIGHT, value=lambda: DEFAULT_UPSCALE_HEIGHT, minimum=0,   maximum=2048, step=8)

        with gr.Row(variant='compact', visible=DEFAULT_VIDEO) as tab_ext_video:
            video_fmt  = gr.Dropdown(label=LABEL_VIDEO_FMT,  value=lambda: DEFAULT_VIDEO_FMT, choices=CHOICES_VIDEO_FMT)
            video_fps  = gr.Number  (label=LABEL_VIDEO_FPS,  value=lambda: DEFAULT_VIDEO_FPS)
            video_pad  = gr.Number  (label=LABEL_VIDEO_PAD,  value=lambda: DEFAULT_VIDEO_PAD,  precision=0)
            video_pick = gr.Text    (label=LABEL_VIDEO_PICK, value=lambda: DEFAULT_VIDEO_PICK, max_lines=1)

        with gr.Row(variant='compact') as tab_ext:
            ext_video     = gr.Checkbox(label=LABEL_VIDEO,      value=lambda: DEFAULT_VIDEO)
            ext_upscale   = gr.Checkbox(label=LABEL_UPSCALE,    value=lambda: DEFAULT_UPSCALE)
            ext_skip_fuse = gr.Checkbox(label=LABEL_SKIP_FUSE,  value=lambda: DEFAULT_SKIP_FUSE)
            dbg_rife      = gr.Checkbox(label=LABEL_DEBUG_RIFE, value=lambda: DEFAULT_DEBUG_RIFE)
        
            ext_video    .change(gr_show, inputs=ext_video,     outputs=tab_ext_video,     show_progress=False)
            ext_upscale  .change(gr_show, inputs=ext_upscale,   outputs=tab_ext_upscale,   show_progress=False)
            ext_skip_fuse.change(gr_show, inputs=ext_skip_fuse, outputs=tab_ext_skip_fuse, show_progress=False)

        skip_fuses = [
            skip_in_0,
            skip_in_1,
            skip_in_2,
            skip_in_3,
            skip_in_4,
            skip_in_5,
            skip_mid,
            skip_out_0,
            skip_out_1,
            skip_out_2,
            skip_out_3,
            skip_out_4,
            skip_out_5,
        ]
        return [
            interp_meth, steps, ctrlnet_ref_dir,
            upscale_meth, upscale_ratio, upscale_width, upscale_height,
            video_fmt, video_fps, video_pad, video_pick,
            ext_video, ext_upscale, ext_skip_fuse, dbg_rife,
            *skip_fuses,
        ]

    def run(self, p:Processing, 
            interp_meth:str, steps:str, ctrlnet_ref_dir:str, 
            upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int,
            video_fmt:str, video_fps:float, video_pad:int, video_pick:str,
            ext_video:bool, ext_upscale:bool, ext_skip_fuse:bool, dbg_rife:bool,
            *skip_fuses:bool,
        ):

        # Prepare ControlNet
        #self.controlnet_script: ControlNetScript = None
        self.controlnet_script = None
        try:
            for script in p.scripts.alwayson_scripts:
                if hasattr(script, "latest_network") and script.title().lower() == "controlnet":
                    script_args: Tuple[ControlNetUnit] = p.script_args[script.args_from:script.args_to]
                    if not any([u.enabled for u in script_args]): return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not enabled')
                    self.controlnet_script = script
                    break
        except ImportError:
            return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not installed')
        except:
            print_exc()
        if not self.controlnet_script: return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not loaded')

        # Enum lookup
        interp_meth: InterpMethod = InterpMethod(interp_meth)
        video_fmt:   VideoFormat  = VideoFormat (video_fmt)

        # Param check & type convert
        if ext_video:
            if video_pad <  0: return Processed(p, [], p.seed, f'video_pad must >= 0, but got {video_pad}')
            if video_fps <= 0: return Processed(p, [], p.seed, f'video_fps must > 0, but got {video_fps}')
            try: video_slice = parse_slice(video_pick)
            except: return Processed(p, [], p.seed, 'syntax error in video_slice')
        if ext_skip_fuse:
            global skip_fuse_plan
            skip_fuse_plan = skip_fuses

        # Prepare ref-images
        if not ctrlnet_ref_dir: return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir}')
        ctrlnet_ref_dir: Path  = Path(ctrlnet_ref_dir)
        if not ctrlnet_ref_dir.is_dir(): return Processed(p, [], p.seed, f'image folder path is not a directory: {ctrlnet_ref_dir}')
        self.ctrlnet_ref_fps = [fp for fp in list(ctrlnet_ref_dir.iterdir()) if fp.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']]
        n_stages = len(self.ctrlnet_ref_fps)
        if n_stages == 0: return Processed(p, [], p.seed, f'no image files (*.jpg/*.jpeg/*.png/*.bmp/*.webp) found in folder: {ctrlnet_ref_dir}')
        if n_stages == 1: return Processed(p, [], p.seed, 'requires at least two images to travel between, but found only 1 :(')

        # Prepare steps (n_interp)
        try: steps: List[int] = [int(s.strip()) for s in steps.strip().split(',')]
        except: return Processed(p, [], p.seed, f'cannot parse steps options: {steps}')
        if   len(steps) == 1: steps = [steps[0]] * (n_stages - 1)
        elif len(steps) != n_stages - 1: return Processed(p, [], p.seed, f'stage count mismatch: len_steps({len(steps)}) != n_stages-1 ({n_stages - 1})')
        n_frames = sum(steps) + n_stages
        if 'show_debug':
            print('n_stages:', n_stages)
            print('n_frames:', n_frames)
            print('steps:', steps)
        steps.insert(0, -1)     # fixup the first stage
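        # after the fixup, steps[i] is the number of interpolated frames drawn before key frame i;
        # steps[0] = -1 is a placeholder since the first key frame has no preceding interpolation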

        # Custom saving path
        travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
        os.makedirs(travel_path, exist_ok=True)
        travel_number = get_next_sequence_number(travel_path)
        self.log_dp = os.path.join(travel_path, f'{travel_number:05}')
        p.outpath_samples = self.log_dp
        os.makedirs(self.log_dp, exist_ok=True)
        self.tmp_dp = Path(self.log_dp) / 'ctrl_cond'   # cache for rife
        self.tmp_fp = self.tmp_dp / 'tmp.png'           # cache for rife

        # Force Batch Count and Batch Size to 1
        p.n_iter     = 1
        p.batch_size = 1

        # Random unified const seed
        p.seed = get_fixed_seed(p.seed)     # fix it to assure all processes using the same major seed
        self.subseed = p.subseed            # stash it to allow using random subseed for each process (when -1)
        if 'show_debug':
            print('seed:',             p.seed)
            print('subseed:',          p.subseed)
            print('subseed_strength:', p.subseed_strength)
        
        # Start job
        state.job_count = n_frames

        # Pack params
        self.n_stages    = n_stages
        self.steps       = steps
        self.interp_meth = interp_meth
        self.dbg_rife    = dbg_rife

        def upscale_image_callback(params:ImageSaveParams):
            params.image = upscale_image(params.image, p.width, p.height, upscale_meth, upscale_ratio, upscale_width, upscale_height)

        images: List[PILImage] = []
        info: str = None
        try:
            if ext_upscale: on_before_image_saved(upscale_image_callback)

            self.UnetHook_hook_original = UnetHook.hook
            UnetHook.hook = hook_hijack

            [c.clear() for c in caches]
            images, info = self.run_linear(p)
        except:
            info = format_exc()
            print(info)
        finally:
            if self.tmp_fp.exists(): os.unlink(self.tmp_fp)
            [c.clear() for c in caches]

            UnetHook.hook = self.UnetHook_hook_original

            self.controlnet_script.input_image = None
            if self.controlnet_script.latest_network:
                self.controlnet_script.latest_network: UnetHook
                self.controlnet_script.latest_network.restore(p.sd_model.model.diffusion_model)
                self.controlnet_script.latest_network = None

            if ext_upscale: remove_callbacks_for_function(upscale_image_callback)

            reset_cuda()

        # Save video
        if ext_video: save_video(images, video_slice, video_pad, video_fps, video_fmt, os.path.join(self.log_dp, f'travel-{travel_number:05}'))

        return Processed(p, images, p.seed, info)

    def run_linear(self, p:Processing) -> RunResults:
        global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, interp_alpha, interp_ip

        images: List[PILImage] = []
        info: str = None
        def process_p(append:bool=True) -> Optional[List[PILImage]]:
            nonlocal p, images, info
            proc = process_images(p)
            if not info: info = proc.info
            if append: images.extend(proc.images)
            else: return proc.images

        ''' ↓↓↓ rife interp utils ↓↓↓ '''
        def save_ctrl_cond(idx:int):
            self.tmp_dp.mkdir(exist_ok=True)
            for i, x in enumerate(to_hint_cond):
                x = x[0]
                if len(x.shape) == 3:
                    if   x.shape[0] == 1: x = x.squeeze_(0)         # [C=1, H, W] => [H, W]
                    elif x.shape[0] == 3: x = x.permute([1, 2, 0])  # [C=3, H, W] => [H, W, C]
                    else: raise ValueError(f'unknown cond shape: {x.shape}')
                else:
                    raise ValueError(f'unknown cond shape: {x.shape}')
                im = (x.detach().clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
                Image.fromarray(im).save(self.tmp_dp / f'{idx}-{i}.png')
        def rife_interp(i:int, j:int, k:int, alpha:float) -> Tensor:
            ''' interp between i-th and j-th cond of the k-th ctrlnet set '''
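            # NOTE: relies on the external `rife-ncnn-vulkan` binary being invocable through run_cmd;
            # with dbg_rife enabled each intermediate frame is kept on disk instead of overwriting the tmp file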
            fp0 = self.tmp_dp / f'{i}-{k}.png'
            fp1 = self.tmp_dp / f'{j}-{k}.png'
            fpo = self.tmp_dp / f'{i}-{j}-{alpha:.3f}.png' if self.dbg_rife else self.tmp_fp
            assert run_cmd(f'rife-ncnn-vulkan -m rife-v4 -s {alpha:.3f} -0 "{fp0}" -1 "{fp1}" -o "{fpo}"')
            x = torch.from_numpy(np.asarray(Image.open(fpo)) / 255.0)
            if   len(x.shape) == 2: x = x.unsqueeze_(0)             # [H, W] => [C=1, H, W]
            elif len(x.shape) == 3: x = x.permute([2, 0, 1])        # [H, W, C] => [C, H, W]
            else: raise ValueError(f'unknown cond shape: {x.shape}')
            x = x.unsqueeze(dim=0)
            return x
        ''' ↑↑↑ rife interp utils ↑↑↑ '''

        ''' ↓↓↓ filename reorder utils ↓↓↓ '''
        iframe = 0
        def rename_image_filename(idx:int, param: ImageSaveParams):
            fn = param.filename
            stem, suffix = os.path.splitext(os.path.basename(fn))
            param.filename = os.path.join(os.path.dirname(fn), f'{idx:05d}' + suffix)
        class on_before_image_saved_wrapper:
            def __init__(self, callback_fn):
                self.callback_fn = callback_fn
            def __enter__(self):
                on_before_image_saved(self.callback_fn)
            def __exit__(self, exc_type, exc_value, exc_traceback):
                remove_callbacks_for_function(self.callback_fn)
        ''' ↑↑↑ filename reorder utils ↑↑↑ '''

        # Step 1: draw the init image
        setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[0])])
        interp_alpha = 0.0
        with on_before_image_saved_wrapper(partial(rename_image_filename, 0)):
            process_p()
            iframe += 1
        save_ctrl_cond(0)

        # travel through stages
        for i in range(1, self.n_stages):
            if state.interrupted: break

            # Step 3: move to the next stage (draw its key frame first; its images are appended after the interpolated frames)
            from_hint_cond       = [t for t in to_hint_cond]       ; to_hint_cond      .clear()
            from_control_tensors = [t for t in to_control_tensors] ; to_control_tensors.clear()
            setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[i])])
            interp_alpha = 0.0

            with on_before_image_saved_wrapper(partial(rename_image_filename, iframe + self.steps[i])):
                cached_images = process_p(append=False)
            save_ctrl_cond(i)

            # Step 2: draw the interpolated images
            is_interrupted = False
            n_inter = self.steps[i] + 1
            for t in range(1, n_inter):
                if state.interrupted: is_interrupted = True ; break

                interp_alpha = t / n_inter     # [1/T, 2/T, .. T-1/T]

                mid_hint_cond.clear()
                device = devices.get_device_for("controlnet")
                if self.interp_meth == InterpMethod.LINEAR:
                    for hintA, hintB in zip(from_hint_cond, to_hint_cond):
                        hintC = weighted_sum(hintA.to(device), hintB.to(device), interp_alpha)
                        mid_hint_cond.append(hintC)
                elif self.interp_meth == InterpMethod.RIFE:
                    dtype = to_hint_cond[0].dtype
                    for k in range(len(to_hint_cond)):
                        hintC = rife_interp(i-1, i, k, interp_alpha).to(device, dtype)
                        mid_hint_cond.append(hintC)
                else: raise ValueError(f'unknown interp_meth: {self.interp_meth}')

                interp_ip = 0
                with on_before_image_saved_wrapper(partial(rename_image_filename, iframe)):
                    process_p()
                    iframe += 1

            # adjust order
            images.extend(cached_images)
            iframe += 1

            if is_interrupted: break

        return images, info