File size: 46,087 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
from typing import Callable

import math
import torch
from torch import Tensor
from torch.nn.functional import group_norm
from einops import rearrange

import comfy.model_management
import comfy.model_patcher
import comfy.patcher_extension
import comfy.samplers
import comfy.sampler_helpers
import comfy.utils
from comfy.controlnet import ControlBase
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.patcher_extension import WrapperExecutor, WrappersMP
import comfy.conds
import comfy.ops

from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows
from .context_extras import ContextRefMode
from .sample_settings import SampleSettings, NoisedImageToInject
from .utils_model import MachineState, vae_encode_raw_batched, vae_decode_raw_batched
from .utils_motion import composite_extend, prepare_mask_batch, extend_to_batch_size
from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup, get_mm_attachment
from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion
from .logger import logger


##################################################################################
######################################################################
# Global variable to use to more conveniently hack variable access into samplers
class AnimateDiffGlobalState:
    """State shared by the sampling wrappers in this file for one sampling run.

    Holds the model patcher, motion model group, injection params and sample
    settings in use, plus step counters and a dict that the sampling callback
    writes into.
    """

    def __init__(self):
        self.model_patcher: ModelPatcher = None
        self.motion_models: MotionModelGroup = None
        self.params: InjectionParams = None
        self.sample_settings: SampleSettings = None
        self.callback_output_dict: dict[str] = {}
        self.function_injections: FunctionInjectionHolder = None
        # reset() is also what creates the flags and step counters
        self.reset()

    def initialize(self, model: BaseModel):
        """Initialize timesteps on motion models, context options and custom cfg (first call only).

        Intended to be run from inside the sampling function; later calls are
        no-ops until reset() clears the ``initialized`` flag.
        """
        if self.initialized:
            return
        self.initialized = True
        if self.motion_models is not None:
            self.motion_models.initialize_timesteps(model)
        context_options = self.params.context_options
        if context_options is not None:
            context_options.initialize_timesteps(model)
        custom_cfg = self.sample_settings.custom_cfg
        if custom_cfg is not None:
            custom_cfg.initialize_timesteps(model)

    def prepare_current_keyframes(self, x: Tensor, timestep: Tensor):
        """Let motion models, context options and custom cfg prepare for ``timestep``."""
        if self.motion_models is not None:
            self.motion_models.prepare_current_keyframe(x=x, t=timestep)
        context_options = self.params.context_options
        if context_options is not None:
            context_options.prepare_current(t=timestep)
        custom_cfg = self.sample_settings.custom_cfg
        if custom_cfg is not None:
            custom_cfg.prepare_current_keyframe(t=timestep)

    def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor, model_options: dict[str]):
        """Apply PIA / FancyVideo handling for every special motion model that is in effect.

        Injects the special conv_in into ``model``, adds c_concat to ``conds``
        and, for FancyVideo, stores embedding patches under
        ``model_options["transformer_options"]["patches"]["emb_patch"]``.

        Returns the (possibly replaced) ``conds``.
        """
        if self.motion_models is None:
            return conds
        for special_model in self.motion_models.get_special_models():
            if not special_model.model.is_in_effect():
                continue
            attachment = get_mm_attachment(special_model)
            if attachment.is_pia(special_model):
                special_model.model.inject_unet_conv_in_pia_fancyvideo(model)
                conds = get_conds_with_c_concat(conds,
                                                attachment.get_pia_c_concat(model, x_in))
            elif attachment.is_fancyvideo(special_model):
                # TODO: handle other weights
                special_model.model.inject_unet_conv_in_pia_fancyvideo(model)
                conds = get_conds_with_c_concat(conds,
                                                attachment.get_fancy_c_concat(model, x_in))
                # add fps_embedding/motion_embedding patches
                emb_patches = special_model.model.get_fancyvideo_emb_patches(dtype=x_in.dtype, device=x_in.device)
                transformer_options = model_options["transformer_options"]
                patches = transformer_options.get("patches", {})
                patches["emb_patch"] = emb_patches
                transformer_options["patches"] = patches
        return conds

    def restore_special_model_features(self, model: BaseModel):
        """Undo the conv_in injection for PIA / FancyVideo special models, in reverse order."""
        if self.motion_models is None:
            return
        for special_model in reversed(self.motion_models.get_special_models()):
            attachment = get_mm_attachment(special_model)
            # TODO: fill out FancyVideo-specific restoration; for now both
            # formats are restored the same way
            if attachment.is_pia(special_model) or attachment.is_fancyvideo(special_model):
                special_model.model.restore_unet_conv_in_pia_fancyvideo(model)

    def reset(self):
        """Return to the pristine state: clear flags, counters, callback output and held references."""
        self.initialized = False
        self.hooks_initialized = False
        self.start_step: int = 0
        self.last_step: int = 0
        self.current_step: int = 0
        self.total_steps: int = 0
        # empty the existing dict in place as well as swapping in a new one
        self.callback_output_dict.clear()
        self.callback_output_dict = {}
        if self.model_patcher is not None:
            self.model_patcher.clean_hooks()
        self.model_patcher = None
        self.motion_models = None
        if self.params is not None:
            self.params.context_options.reset()
        self.params = None
        self.sample_settings = None
        self.function_injections = None

    def update_with_inject_params(self, params: InjectionParams):
        """Store ``params`` as the injection params in use."""
        self.params = params

    def is_using_sliding_context(self):
        """Return True when params are set and they report sliding context as active."""
        if self.params is None:
            return False
        return self.params.is_using_sliding_context()

    def create_exposed_params(self):
        """Build the dict placed in transformer_options as ``ad_params``."""
        # This dict will be exposed to be used by other extensions
        # DO NOT change any of the key names
        # or I will find you 👁.👁
        params = self.params
        return {
            "full_length": params.full_length,
            "context_length": params.context_options.context_length,
            "sub_idxs": params.sub_idxs,
        }
######################################################################
##################################################################################


##################################################################################
#### Code Injection ##################################################
def unlimited_memory_required(*args, **kwargs):
    """Stand-in for a model's ``memory_required``: accepts anything and always reports 0."""
    required = 0
    return required


def groupnorm_mm_factory(params: InjectionParams, manual_cast=False):
    """Create a replacement GroupNorm forward that applies group_norm on a ``b c f h w`` view.

    With ``manual_cast`` set, weight and bias are obtained through
    ``comfy.ops.cast_bias_weight``; otherwise the module's own are used.
    """
    def groupnorm_mm_forward(self, input: Tensor) -> Tensor:
        # The number of conds/unconds batched together is derived from the batch size,
        # because it can change when VRAM optimizations kick in. The divisor is the
        # context window length under sliding context, otherwise the full length.
        if params.is_using_sliding_context():
            batched_conds = input.size(0)//params.context_options.context_length
        else:
            batched_conds = input.size(0)//params.full_length

        hidden = rearrange(input, "(b f) c h w -> b c f h w", b=batched_conds)
        if manual_cast:
            weight, bias = comfy.ops.cast_bias_weight(self, hidden)
        else:
            weight, bias = self.weight, self.bias
        hidden = group_norm(hidden, self.num_groups, weight, bias, self.eps)
        return rearrange(hidden, "b c f h w -> (b f) c h w", b=batched_conds)
    return groupnorm_mm_forward


def create_special_model_apply_model_wrapper(model_options: dict):
    """Register ``_apply_model_wrapper`` as an APPLY_MODEL wrapper on ``model_options``."""
    wrapper_key = "ADE_special_model_apply_model"
    comfy.patcher_extension.add_wrapper_with_key(
        WrappersMP.APPLY_MODEL,
        wrapper_key,
        _apply_model_wrapper,
        model_options, is_model_options=True,
    )

def _apply_model_wrapper(executor, *args, **kwargs):
    """APPLY_MODEL wrapper: let each motion model's attachment prepare its features, then run the call.

    Positional args come from BaseModel._apply_model:
        0: x, 1: t, 2: c_concat, 3: c_crossattn, 4: control, 5: transformer_options
    """
    latent_x: Tensor = args[0]
    t_opts = args[5]
    cond_or_uncond = t_opts["cond_or_uncond"]
    ad_params = t_opts["ad_params"]
    ADGS: AnimateDiffGlobalState = t_opts["ADGS"]
    if ADGS.motion_models is not None:
        for motion_model in ADGS.motion_models.models:
            attachment = get_mm_attachment(motion_model)
            attachment.prepare_alcmi2v_features(
                motion_model, x=latent_x, cond_or_uncond=cond_or_uncond, ad_params=ad_params,
                latent_format=executor.class_obj.latent_format)
            attachment.prepare_camera_features(
                motion_model, x=latent_x, cond_or_uncond=cond_or_uncond, ad_params=ad_params)
    # drop the local reference before handing off to the wrapped call
    del latent_x
    return executor(*args, **kwargs)


def create_diffusion_model_groupnormed_wrapper(model_options: dict, inject_helper: 'GroupnormInjectHelper'):
    """Register a DIFFUSION_MODEL wrapper on ``model_options`` that runs the call inside ``inject_helper``."""
    groupnormed_wrapper = _diffusion_model_groupnormed_wrapper_factory(inject_helper)
    comfy.patcher_extension.add_wrapper_with_key(
        WrappersMP.DIFFUSION_MODEL,
        "ADE_groupnormed_diffusion_model",
        groupnormed_wrapper,
        model_options, is_model_options=True,
    )

def _diffusion_model_groupnormed_wrapper_factory(inject_helper: 'GroupnormInjectHelper'):
    """Return a wrapper that enters ``inject_helper`` around the wrapped executor call."""
    def _diffusion_model_groupnormed_wrapper(executor, *call_args, **call_kwargs):
        # the helper is active only for the duration of this one call
        with inject_helper:
            return executor(*call_args, **call_kwargs)
    return _diffusion_model_groupnormed_wrapper
######################################################################
##################################################################################


def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionParams):
    """Clone ``params``, choose sliding-context vs regular sampling, and set the motion models' video length.

    Args:
        helper: gives access to the attached motion models (``get_motion_models``,
            ``set_video_length``, ``get_name_string``).
        params: injection params for this run; ``full_length`` is read as the
            number of latents passed in.

    Returns:
        The adjusted clone of ``params``. Its context is reset when sliding
        context sampling is not activated.

    Raises:
        ValueError: if a motion model reports the intended frame window as
            invalid for its ``encoding_max_len``.
    """
    params = params.clone()
    for context in params.context_options.contexts:
        if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT:
            context.context_length = params.full_length

    allow_equal = params.context_options.use_on_equal_length
    context_length = params.context_options.context_length
    if not context_length:
        enough_latents = False
    elif allow_equal:
        enough_latents = params.full_length >= context_length
    else:
        enough_latents = params.full_length > context_length

    # log messages state the comparison that was actually applied
    if enough_latents:
        relation = "greater than or equal to" if allow_equal else "greater than"
        logger.info(f"Sliding context window sampling activated - latents passed in ({params.full_length}) {relation} context_length {context_length}.")
    else:
        if not context_length:
            logger.info(f"Regular sampling activated - no context_length set for latents passed in ({params.full_length}).")
        else:
            relation = "less than" if allow_equal else "less than or equal to"
            logger.info(f"Regular sampling activated - latents passed in ({params.full_length}) {relation} context_length {context_length}.")
        params.reset_context()

    motion_models = helper.get_motion_models()
    if motion_models:
        # re-read after the possible reset_context() call above, in case it changed the value
        context_length = params.context_options.context_length
        # if no context_length, treat video length as intended AD frame window
        if not context_length:
            for motion_model in motion_models:
                if not motion_model.model.is_length_valid_for_encoding_max_len(params.full_length):
                    raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.full_length} latents.")
            helper.set_video_length(params.full_length, params.full_length)
        # otherwise, treat context_length as intended AD frame window
        else:
            # view options, when present, supply the length that is validated
            view_options = params.context_options.view_options
            window_length = view_options.context_length if view_options else context_length
            for motion_model in motion_models:
                if not motion_model.model.is_length_valid_for_encoding_max_len(window_length):
                    # report the length that was actually validated
                    raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {window_length}.")
            helper.set_video_length(context_length, params.full_length)
        # inject model
        module_str = "modules" if len(motion_models) > 1 else "module"
        logger.info(f"Using motion {module_str} {helper.get_name_string(show_version=True)}.")
    return params


class FunctionInjectionHolder:
    """Saves the functions this file overrides, installs the overrides, and restores the originals.

    ``inject_functions`` records the original functions as ``orig_*`` attributes;
    ``restore_functions`` puts them back. The ``orig_*`` attributes do not exist
    until ``inject_functions`` has run.
    """

    def __init__(self):
        self.temp_uninjector: GroupnormUninjectHelper = GroupnormUninjectHelper()
        self.groupnorm_injector: GroupnormInjectHelper = GroupnormInjectHelper()

    def _override_memory_required(self, helper: ModelPatcherHelper):
        """Replace the model's ``memory_required`` with ``unlimited_memory_required``.

        The original is saved only the first time, so a second override cannot
        overwrite it with the already-installed replacement; that would make
        ``restore_functions`` put the replacement back instead of the original.
        """
        if self.orig_memory_required is None:
            self.orig_memory_required = helper.model.model.memory_required
        helper.model.model.memory_required = unlimited_memory_required

    def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, model_options: dict):
        """Save the original functions and install the overrides needed for this run.

        Args:
            helper: gives access to the model and its motion models.
            params: injection params; ``unlimited_area_hack`` and ``apply_v2_properly`` are read.
            model_options: model options that the wrappers get registered on.
        """
        # Save Original Functions - order must match between here and restore_functions
        self.orig_memory_required = None
        self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames
        self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights
        self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers
        # Inject Functions
        if params.unlimited_area_hack:
            # allows for "unlimited area hack" to prevent halving of conds/unconds
            self._override_memory_required(helper)
        if helper.get_motion_models():
            # only apply groupnorm hack if PIA, v2 and not properly applied, or v1
            info: AnimateDiffInfo = helper.get_motion_models()[0].model.mm_info
            if ((info.mm_format == AnimateDiffFormat.PIA) or
                (info.mm_version == AnimateDiffVersion.V2 and not params.apply_v2_properly) or
                (info.mm_version == AnimateDiffVersion.V1)):
                self.inject_groupnorm_forward = groupnorm_mm_factory(params)
                self.inject_groupnorm_forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
                self.groupnorm_injector = GroupnormInjectHelper(self)
                create_diffusion_model_groupnormed_wrapper(model_options, self.groupnorm_injector)
                # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack
                try:
                    if helper.model.load_device.type == "mps":
                        self._override_memory_required(helper)
                except Exception:
                    # deliberate best-effort: failing to detect the device must not abort sampling
                    pass
            # if img_encoder or camera_encoder present, inject apply_model to handle correctly
            for motion_model in helper.get_motion_models():
                if (motion_model.model.img_encoder is not None) or (motion_model.model.camera_encoder is not None):
                    create_special_model_apply_model_wrapper(model_options)
                    break
            del info
        comfy.samplers.sampling_function = evolved_sampling_function
        # create temp_uninjector to help facilitate uninjecting functions
        self.temp_uninjector = GroupnormUninjectHelper(self)

    def restore_functions(self, helper: ModelPatcherHelper):
        """Put back every function saved by ``inject_functions``.

        If an ``orig_*`` attribute is missing (``inject_functions`` did not get
        far enough to save it), the AttributeError is logged, not raised.
        """
        # Restoration
        try:
            if self.orig_memory_required is not None:
                helper.model.model.memory_required = self.orig_memory_required
            torch.nn.GroupNorm.forward = self.orig_groupnorm_forward
            comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights
            comfy.samplers.sampling_function = self.orig_sampling_function
        except AttributeError:
            logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \
                         "to save original functions before injection, and a more specific error was thrown by ComfyUI.")


class GroupnormUninjectHelper:
    """Context manager that temporarily puts the GroupNorm functions back to the originals saved on ``holder``.

    On exit, whatever functions were installed at entry are reinstated.
    With no holder, entering and exiting do nothing.
    """

    def __init__(self, holder: FunctionInjectionHolder=None):
        self.holder = holder
        self.previous_gn_forward = None
        self.previous_dwi_gn_cast_weights = None

    def __enter__(self):
        if self.holder is not None:
            torch_gn = torch.nn.GroupNorm
            comfy_gn = comfy.ops.disable_weight_init.GroupNorm
            # remember what is installed right now
            self.previous_gn_forward = torch_gn.forward
            self.previous_dwi_gn_cast_weights = comfy_gn.forward_comfy_cast_weights
            # swap in the originals saved by the holder
            torch_gn.forward = self.holder.orig_groupnorm_forward
            comfy_gn.forward_comfy_cast_weights = self.holder.orig_groupnorm_forward_comfy_cast_weights
        return self

    def __exit__(self, *args, **kwargs):
        if self.holder is not None:
            # reinstate the functions that were active at entry
            torch.nn.GroupNorm.forward = self.previous_gn_forward
            comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.previous_dwi_gn_cast_weights
            self.previous_gn_forward = None
            self.previous_dwi_gn_cast_weights = None


class GroupnormInjectHelper:
    """Context manager that temporarily installs the holder's injected GroupNorm functions.

    On exit, whatever functions were installed at entry are reinstated.
    With no holder, entering and exiting do nothing.
    """

    def __init__(self, holder: FunctionInjectionHolder=None):
        self.holder = holder
        self.previous_gn_forward = None
        self.previous_dwi_gn_cast_weights = None

    def __enter__(self):
        if self.holder is not None:
            torch_gn = torch.nn.GroupNorm
            comfy_gn = comfy.ops.disable_weight_init.GroupNorm
            # remember what is installed right now
            self.previous_gn_forward = torch_gn.forward
            self.previous_dwi_gn_cast_weights = comfy_gn.forward_comfy_cast_weights
            # swap in the injected groupnorm functions
            torch_gn.forward = self.holder.inject_groupnorm_forward
            comfy_gn.forward_comfy_cast_weights = self.holder.inject_groupnorm_forward_comfy_cast_weights
        return self

    def __exit__(self, *args, **kwargs):
        if self.holder is not None:
            # reinstate the functions that were active at entry
            torch.nn.GroupNorm.forward = self.previous_gn_forward
            comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.previous_dwi_gn_cast_weights
            self.previous_gn_forward = None
            self.previous_dwi_gn_cast_weights = None


def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs):
    """OUTER_SAMPLE wrapper that sets up, runs and tears down AnimateDiff sampling.

    Swaps in a clone of the guider's model options carrying a fresh
    AnimateDiffGlobalState, applies params to the motion models, injects the
    function overrides, wraps the progress callback, and runs the wrapped
    sampler once per iteration (split into several runs when image injection
    is in use).

    Positional args read here: ``args[0]`` noise, ``args[1]`` latents,
    ``args[3]`` sigmas, ``args[-3]`` callback, ``args[-1]`` seed.

    Returns:
        The latents returned by the last wrapped sampler run (after any image injection).
    """
    # NOTE: OUTER_SAMPLE wrapper patch in ModelPatcher
    # Everything the finally block touches is bound up front, so a failure early
    # in the try body is not masked by an UnboundLocalError during cleanup.
    noise = None
    latents = None
    cached_latents = None
    cached_noise = None
    guider = None
    helper = None
    ADGS = None
    orig_model_options = None
    model_options_captured = False
    function_injections = FunctionInjectionHolder()

    try:
        guider = executor.class_obj
        helper = ModelPatcherHelper(guider.model_patcher)

        orig_model_options = guider.model_options
        model_options_captured = True
        guider.model_options = comfy.model_patcher.create_model_options_clone(guider.model_options)
        # create ADGS in transformer_options
        ADGS = AnimateDiffGlobalState()
        guider.model_options["transformer_options"]["ADGS"] = ADGS

        args = list(args)
        # clone params from model
        params = helper.get_params().clone()
        # get amount of latents passed in, and store in params
        noise = args[0]
        latents = args[1]
        params.full_length = latents.size(0)
        # reset global state
        ADGS.reset()

        # apply custom noise, if needed
        disable_noise = math.isclose(noise.max(), 0.0)
        seed = args[-1]

        # apply params to motion model
        params = apply_params_to_motion_models(helper, params)

        # store and inject functions
        function_injections.inject_functions(helper, params, guider.model_options)

        # prepare noise_extra_args for noise generation purposes
        noise_extra_args = {"disable_noise": disable_noise}
        params.set_noise_extra_args(noise_extra_args)
        # if noise is not disabled, do noise stuff
        if not disable_noise:
            noise = helper.get_sample_settings().prepare_noise(seed, latents, noise, extra_args=noise_extra_args, force_create_noise=False)

        # callback setup
        original_callback = args[-3]
        def ad_callback(step, x0, x, total_steps):
            if original_callback is not None:
                original_callback(step, x0, x, total_steps)
            # store denoised latents if image_injection will be used
            if not helper.get_sample_settings().image_injection.is_empty():
                ADGS.callback_output_dict["x0"] = x0
            # update GLOBALSTATE for next iteration
            ADGS.current_step = ADGS.start_step + step + 1
        args[-3] = ad_callback
        ADGS.model_patcher = helper.model
        ADGS.motion_models = MotionModelGroup(helper.get_motion_models())
        ADGS.sample_settings = helper.get_sample_settings()
        ADGS.function_injections = function_injections

        # apply adapt_denoise_steps - does not work here! would need to mess with this elsewhere...
        # TODO: implement proper wrapper to handle this feature...

        iter_opts = helper.get_sample_settings().iteration_opts
        iter_opts.initialize(latents)
        # cache initial noise and latents, if needed
        if iter_opts.cache_init_latents:
            cached_latents = latents.clone()
        if iter_opts.cache_init_noise:
            cached_noise = noise.clone()
        # prepare iter opts preprocess kwargs, if needed
        iter_kwargs = {}
        # NOTE: original KSampler stuff is not doable here, so skipping...

        for curr_i in range(iter_opts.iterations):
            # handle GLOBALSTATE vars and step tally
            # NOTE: only KSampler/KSampler (Advanced) would have steps;
            # explore modifying ComfyUI to provide this when possible?
            ADGS.update_with_inject_params(params)
            ADGS.start_step = kwargs.get("start_step") or 0
            ADGS.current_step = ADGS.start_step
            ADGS.last_step = kwargs.get("last_step") or 0
            if iter_opts.iterations > 1:
                logger.info(f"Iteration {curr_i+1}/{iter_opts.iterations}")
            # perform any iter_opts preprocessing on latents
            latents, noise = iter_opts.preprocess_latents(curr_i=curr_i, model=helper.model, latents=latents, noise=noise,
                                                          cached_latents=cached_latents, cached_noise=cached_noise,
                                                          seed=seed,
                                                          sample_settings=helper.get_sample_settings(), noise_extra_args=noise_extra_args,
                                                          **iter_kwargs)
            if helper.get_sample_settings().noise_calibration is not None:
                latents, noise = helper.get_sample_settings().noise_calibration.perform_calibration(sample_func=executor, model=helper.model, latents=latents, noise=noise,
                                                                                                    is_custom=True, args=args, kwargs=kwargs)
            # finalize latent_image in args
            args[0] = noise
            args[1] = latents

            helper.pre_run()

            if ADGS.sample_settings.image_injection.is_empty():
                latents = executor(*tuple(args), **kwargs)
            else:
                ADGS.sample_settings.image_injection.initialize_timesteps(helper.model.model)
                sigmas = args[3]
                sigmas_list, injection_list = ADGS.sample_settings.image_injection.custom_ksampler_get_injections(helper.model, sigmas)
                # useful logging
                if len(injection_list) > 0:
                    inj_str = "s" if len(injection_list) > 1 else ""
                    logger.info(f"Found {len(injection_list)} applicable image injection{inj_str}; sampling will be split into {len(sigmas_list)}.")
                else:
                    logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.")
                new_noise = noise
                for i, split_sigmas in enumerate(sigmas_list):
                    args[0] = new_noise
                    args[1] = latents
                    args[3] = split_sigmas
                    latents = executor(*tuple(args), **kwargs)
                    # every split after the first is sampled with all-zero noise
                    new_noise = torch.zeros_like(latents)
                    # if injection expected, perform injection
                    if i < len(injection_list):
                        to_inject = injection_list[i]
                        latents = perform_image_injection(ADGS, helper.model.model, latents, to_inject)
        return latents
    finally:
        # put the guider's original model options back, if they were ever captured
        if model_options_captured:
            guider.model_options = orig_model_options
        del noise
        del latents
        del cached_latents
        del cached_noise
        del orig_model_options
        # reset global state
        if ADGS is not None:
            ADGS.reset()
        if helper is not None:
            # clean motion_models
            helper.cleanup_motion_models()
            # restore injected functions
            function_injections.restore_functions(helper)
        del function_injections
        del helper


def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cond_scale, model_options: dict={}, seed=None):
    """Run one CFG sampling step with AnimateDiff-Evolved features applied.

    The AnimateDiffGlobalState is read from
    ``model_options["transformer_options"]["ADGS"]``. Copies of ``model_options``
    and its ``transformer_options`` are made so that the exposed AD params
    (``"ad_params"``) are not written into the caller's dictionaries.

    The uncond batch is skipped (passed as ``None`` to ``calc_cond_batch``) when
    the effective cfg is close to 1.0 and ``disable_cfg1_optimization`` is not
    set: ``cond_scale`` is checked when no custom cfg is configured, otherwise
    the custom cfg's non-Tensor ``cfg_multival`` is checked.

    Returns whatever ``comfy.samplers.cfg_function`` returns. ``seed`` is
    accepted but not used in this function. Special model features are always
    restored on exit, even if sampling raises.
    """
    ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"]
    ADGS.initialize(model)
    ADGS.prepare_current_keyframes(x=x, timestep=timestep)
    try:
        # add AD/evolved-sampling params to copies of model_options/transformer_options
        model_options = model_options.copy()
        # the ADGS lookup above already requires "transformer_options"; the fallback is kept for parity
        if "transformer_options" in model_options:
            transformer_opts = model_options["transformer_options"].copy()
        else:
            transformer_opts = {}
        model_options["transformer_options"] = transformer_opts
        transformer_opts["ad_params"] = ADGS.create_exposed_params()

        cond, uncond = ADGS.perform_special_model_features(model, [cond, uncond], x, model_options)

        # only use cfg1_optimization if not using custom_cfg or explicitly set to 1.0
        custom_cfg = ADGS.sample_settings.custom_cfg
        if custom_cfg is None:
            skip_uncond = (
                math.isclose(cond_scale, 1.0)
                and model_options.get("disable_cfg1_optimization", False) == False
            )
        else:
            multival = custom_cfg.cfg_multival
            skip_uncond = (
                type(multival) != Tensor
                and math.isclose(multival, 1.0)
                and model_options.get("disable_cfg1_optimization", False) == False
            )
        batch_uncond = None if skip_uncond else uncond

        cond_pred, uncond_pred = comfy.samplers.calc_cond_batch(model, [cond, batch_uncond], x, timestep, model_options)

        # custom cfg (if any) overrides the scale and may adjust model_options
        custom_cfg = ADGS.sample_settings.custom_cfg
        if custom_cfg is not None:
            cond_scale = custom_cfg.get_cfg_scale(cond_pred)
            model_options = custom_cfg.get_model_options(model_options)

        return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond)
    finally:
        ADGS.restore_special_model_features(model)


def perform_image_injection(ADGS: AnimateDiffGlobalState, model: BaseModel, latents: Tensor, to_inject: NoisedImageToInject) -> Tensor:
    """Composite ``to_inject``'s image into the sampled latents through a VAE round trip.

    The denoised ``x0`` stored in ``ADGS.callback_output_dict`` is decoded to
    images and re-encoded; the difference between ``latents`` and that re-encoded
    x0 is kept. The image to inject is composited onto the decoded images, the
    result is encoded, and the kept difference is added back. The outcome is
    blended with ``latents`` according to ``to_inject.strength_multival``.

    Returns ``latents`` unchanged when no ``"x0"`` entry is available; otherwise
    returns the injected latents cast back to the dtype/device ``latents`` had on
    entry. The models that were in use on entry are reloaded in all cases.
    """
    # NOTE: the latents here have already been process_latent_out'ed
    # remember the models currently in use so they can be reloaded after the VAE work
    models_in_use = comfy.model_management.loaded_models(only_currently_used=True)
    try:
        return_device, return_dtype = latents.device, latents.dtype
        # follow same steps as in KSampler Custom to get same denoised_x0 value
        denoised = ADGS.callback_output_dict.get("x0", None)
        if denoised is None:
            return latents
        # x0 should be process_latent_out'ed to match expected state of latents between nodes
        denoised = model.process_latent_out(denoised)

        # round trip: decode x0 into images, then encode those images again
        vae = to_inject.vae
        decoded_images = vae_decode_raw_batched(vae, denoised)
        reencoded = vae_encode_raw_batched(vae, decoded_images)

        # difference between the sampled latents and the round-tripped x0
        latents = latents.to(device=reencoded.device)
        residual = latents - reencoded
        del reencoded

        # use the provided mask, or fall back to a full mask sized to the latent
        batch_size, _, height, width = residual.shape
        mask = to_inject.mask
        if mask is None:
            mask = torch.ones(1, height, width)
        if to_inject.invert_mask:
            mask = 1.0 - mask

        # composite the image to inject over the decoded images;
        # dims are moved to (b,c,h,w) for compositing and back afterwards
        opts = to_inject.img_inject_opts
        composited = composite_extend(
            destination=decoded_images.movedim(-1, 1),
            source=to_inject.image.movedim(-1, 1),
            x=opts.x,
            y=opts.y,
            mask=mask,
            multiplier=vae.downscale_ratio,
            resize_source=to_inject.resize_image,
        ).movedim(1, -1)
        # latent representation of the composite, with the residual added back
        composited = vae_encode_raw_batched(vae, composited)
        composited += residual

        strength = to_inject.strength_multival
        if type(strength) == float and math.isclose(1.0, strength):
            # full strength: the composite replaces the latents outright
            blended = composited
        else:
            if type(strength) == Tensor:
                strength = extend_to_batch_size(prepare_mask_batch(strength, composited.shape), batch_size)
            blended = composited * strength + latents * (1.0 - strength)
        return blended.to(dtype=return_dtype, device=return_device)
    finally:
        comfy.model_management.load_models_gpu(models_in_use)


# initial sliding_calc_cond_batch inspired by ashen's initial hack for 16-frame sliding context:
# https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master
def sliding_calc_cond_batch(executor: Callable, model, conds: list[list[dict]], x_in: Tensor, timestep, model_options):
    """Evaluate ``executor`` over sliding context windows and fuse the per-window outputs.

    When the AnimateDiffGlobalState found in
    ``model_options["transformer_options"]["ADGS"]`` is not using sliding
    context, this simply forwards to ``executor`` with the arguments unchanged.

    Otherwise, context windows are obtained from ``get_context_windows`` and,
    for each window, ``x_in``, ``timestep`` and every cond are subset to the
    window's indexes before calling ``executor``. Window outputs are fused into
    one tensor per cond (each shaped like ``x_in``):

    * ``ContextFuseMethod.RELATIVE``: a running weighted average per index,
      weighted by a bias that decreases with distance from the window center.
    * any other fuse method: a weighted sum using ``get_context_weights``,
      divided at the end by the accumulated weights.

    Side effects visible in this function: ``ADGS.params.sub_idxs`` and
    ``ADGS.params.context_options.step`` are updated, motion models (if any)
    receive the view options / sub idxs / video length, and
    ``model_options["transformer_options"]`` is mutated in place (the
    ``"ad_params"`` entries and the ContextRef machine state key).

    Raises:
        ValueError: if a control object in the conds lacks a ``sub_idxs`` attribute.
        Exception: if ContextRef should run but the expected ContextRef control
            list is missing or empty in ``transformer_options``.

    Returns:
        list with one fused tensor per entry of ``conds``.
    """
    ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"]
    if not ADGS.is_using_sliding_context():
        return executor(model, conds, x_in, timestep, model_options)

    def prepare_control_objects(control: ControlBase, full_idxs: list[int]):
        """Recursively tell ``control`` and its chained previous controlnets which indexes are in the current window."""
        if control.previous_controlnet is not None:
            prepare_control_objects(control.previous_controlnet, full_idxs)
        if not hasattr(control, "sub_idxs"):
            raise ValueError(f"Control type {type(control).__name__} may not support required features for sliding context window; \
                                use ControlNet nodes from Kosinkadink/ComfyUI-Advanced-ControlNet, or make sure ComfyUI-Advanced-ControlNet is updated.")
        control.sub_idxs = full_idxs
        control.full_latent_length = ADGS.params.full_length
        control.context_length = ADGS.params.context_options.context_length
    
    def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list:
        """Return shallow copies of the cond dicts with batch-sized tensors subset to ``full_idxs``.

        Tensors whose first dim equals ``x_in.size(0)`` are indexed with
        ``full_idxs``; everything else is reused as-is. Returns None when
        ``cond_in`` is None.
        """
        if cond_in is None:
            return None
        # reuse or resize cond items to match context requirements
        resized_cond = []
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond_in:
            resized_actual_cond = actual_cond.copy()
            # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
            for key in actual_cond:
                try:
                    cond_item = actual_cond[key]
                    if isinstance(cond_item, Tensor):
                        # check that tensor is the expected length - x.size(0)
                        if cond_item.size(0) == x_in.size(0):
                            # if so, it's subsetting time - tell controls the expected indices so they can handle them
                            actual_cond_item = cond_item[full_idxs]
                            resized_actual_cond[key] = actual_cond_item
                        else:
                            resized_actual_cond[key] = cond_item
                    # look for control
                    elif key == "control":
                        control_item = cond_item
                        # the control object itself is not copied; it is updated in place with the window idxs
                        prepare_control_objects(control_item, full_idxs)
                        resized_actual_cond[key] = control_item
                        del control_item
                    elif isinstance(cond_item, dict):
                        new_cond_item = cond_item.copy()
                        # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                        # (only values of existing keys are replaced while iterating, so the dict size never changes)
                        for cond_key, cond_value in new_cond_item.items():
                            if isinstance(cond_value, Tensor):
                                if cond_value.size(0) == x_in.size(0):
                                    new_cond_item[cond_key] = cond_value[full_idxs]
                            # if has cond that is a Tensor, check if needs to be subset
                            elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, Tensor):
                                if cond_value.cond.size(0) == x_in.size(0):
                                    new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond[full_idxs])
                            elif cond_key == "num_video_frames": # for SVD
                                # copy the cond object, then overwrite its value with the window's length
                                new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
                                new_cond_item[cond_key].cond = context_length
                        resized_actual_cond[key] = new_cond_item
                    else:
                        resized_actual_cond[key] = cond_item
                finally:
                    del cond_item  # just in case to prevent VRAM issues
            resized_cond.append(resized_actual_cond)
        return resized_cond

    # get context windows
    ADGS.params.context_options.step = ADGS.current_step
    context_windows = get_context_windows(ADGS.params.full_length, ADGS.params.context_options)

    if ADGS.motion_models is not None:
        ADGS.motion_models.set_view_options(ADGS.params.context_options.view_options)
    
    # prepare final conds, out_counts, and biases
    conds_final = [torch.zeros_like(x_in) for _ in conds]
    if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
        # counts_final not used for RELATIVE fuse_method
        # (initialized to ones so the NaiveReuse divisions/multiplications below are no-ops)
        counts_final = [torch.ones((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
    else:
        # default counts_final initialization
        counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
    # per-cond, per-index accumulated bias; only updated by the RELATIVE fuse path
    biases_final = [([0.0] * x_in.shape[0]) for _ in conds]

    # keys read from / written to model_options["transformer_options"] for ContextRef
    CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all"
    CONTEXTREF_MACHINE_STATE = "contextref_machine_state"
    CONTEXTREF_CLEAN_FUNC = "contextref_clean_func"
    contextref_active = False
    contextref_mode = None
    contextref_idxs_set = None
    first_context = True
    # need to make sure that contextref stuff gets cleaned up, no matter what
    try:
        if ADGS.params.context_options.extras.should_run_context_ref():
            # check that ACN provided ContextRef as requested
            temp_refcn_list = model_options["transformer_options"].get(CONTEXTREF_CONTROL_LIST_ALL, None)
            if temp_refcn_list is None:
                raise Exception("Advanced-ControlNet nodes are either missing or too outdated to support ContextRef. Update/install ComfyUI-Advanced-ControlNet to use ContextRef.")
            if len(temp_refcn_list) == 0:
                raise Exception("Unexpected ContextRef issue; Advanced-ControlNet did not provide any ContextRef objs for AnimateDiff-Evolved.")
            del temp_refcn_list
            # check if ContextRef ReferenceAdvanced ACN objs should_run
            # (every refcn is prepared for the timestep; a single one declining disables ContextRef)
            actually_should_run = True
            for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]:
                refcn.prepare_current_timestep(timestep)
                if not refcn.should_run():
                    actually_should_run = False
            if actually_should_run:
                contextref_active = True
                # NOTE: the mode resolved for the last refcn in the list is the one that ends up being used
                for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]:
                    # get mode_override if present, mode otherwise
                    contextref_mode = refcn.get_contextref_mode_replace() or ADGS.params.context_options.extras.context_ref.mode
                # copied because indexes may be removed from it below (single_trigger)
                contextref_idxs_set = contextref_mode.indexes.copy()

        curr_window_idx = -1
        naivereuse_active = False
        cached_naive_conds = None
        cached_naive_ctx_idxs = None
        if ADGS.params.context_options.extras.should_run_naive_reuse():
            cached_naive_conds = [torch.zeros_like(x_in) for _ in conds]
            #cached_naive_counts = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
            naivereuse_active = True
        # perform calc_conds_batch per context window 
        for ctx_idxs in context_windows:
            # allow processing to end between context window executions for faster Cancel
            comfy.model_management.throw_exception_if_processing_interrupted()
            curr_window_idx += 1
            ADGS.params.sub_idxs = ctx_idxs
            if ADGS.motion_models is not None:
                ADGS.motion_models.set_sub_idxs(ctx_idxs)
                ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length)
            # update exposed params
            model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs
            model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs)
            # get subsections of x, timestep, conds
            sub_x = x_in[ctx_idxs]
            sub_timestep = timestep[ctx_idxs]
            sub_conds = [get_resized_cond(cond, ctx_idxs, len(ctx_idxs)) for cond in conds]

            if contextref_active:
                # set cond counter to 0 (each cond encountered will increment it by 1)
                for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]:
                    refcn.contextref_cond_idx = 0
                # default policy: first window WRITEs, later windows READ
                if first_context:
                    model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE
                else:
                    model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ
                    if contextref_mode.mode == ContextRefMode.SLIDING: # if sliding, check if time to READ and WRITE
                        # NOTE(review): a sliding_width of 1 would make this a modulo by zero - confirm it is validated upstream
                        if curr_window_idx % (contextref_mode.sliding_width-1) == 0:
                            model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE
                # override with indexes mode, if set
                if contextref_mode.mode == ContextRefMode.INDEXES:
                    contains_idx = False
                    for i in ctx_idxs:
                        if i in contextref_idxs_set:
                            contains_idx = True
                            # single trigger decides if each index should only trigger READ_WRITE once per step
                            if not contextref_mode.single_trigger:
                                break
                            contextref_idxs_set.remove(i)
                    if contains_idx:
                        model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE
                        # the first window has nothing to read yet, so it only WRITEs
                        if first_context:
                            model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE
                    else:
                        model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ
            else:
                model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF
            #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}")

            sub_conds_out = executor(model, sub_conds, sub_x, sub_timestep, model_options)

            if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
                full_length = ADGS.params.full_length
                for pos, idx in enumerate(ctx_idxs):
                    # bias is the influence of a specific index in relation to the whole context window
                    # (1 at the midpoint between first and last idx, falling off linearly; the 1e-2 term avoids dividing by zero)
                    bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2)
                    bias = max(1e-2, bias)
                    # take weighted average relative to total bias of current idx
                    for i in range(len(sub_conds_out)):
                        bias_total = biases_final[i][idx]
                        prev_weight = (bias_total / (bias_total + bias))
                        new_weight = (bias / (bias_total + bias))
                        conds_final[i][idx] = conds_final[i][idx] * prev_weight + sub_conds_out[i][pos] * new_weight
                        biases_final[i][idx] = bias_total + bias
            else:
                # add conds and counts based on weights of fuse method
                weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep)
                # three trailing singleton dims are added so the per-frame weights broadcast over the remaining dims
                weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                for i in range(len(sub_conds_out)):
                    conds_final[i][ctx_idxs] += sub_conds_out[i] * weights_tensor
                    counts_final[i][ctx_idxs] += weights_tensor
            # handle NaiveReuse
            # (only the first processed window is cached: the flag is cleared right after)
            if naivereuse_active:
                cached_naive_ctx_idxs = ctx_idxs
                for i in range(len(sub_conds)):
                    cached_naive_conds[i][ctx_idxs] = conds_final[i][ctx_idxs] / counts_final[i][ctx_idxs]
                naivereuse_active = False
            # toggle first_context off, if needed
            if first_context:
                first_context = False
    finally:
        # clean contextref stuff with provided ACN function, if applicable
        if contextref_active:
            model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]()

    # handle NaiveReuse
    # NOTE(review): if context_windows were empty, cached_naive_ctx_idxs would still be None here - confirm that cannot happen
    if cached_naive_conds is not None:
        start_idx = cached_naive_ctx_idxs[0]
        # walk the full length in strides of the cached window's length
        for z in range(0, ADGS.params.full_length, len(cached_naive_ctx_idxs)):
            for i in range(len(cached_naive_conds)):
                # get the 'true' idxs of this window
                new_ctx_idxs = [(zz+start_idx) % ADGS.params.full_length for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length]
                # make sure when getting cached_naive idxs, they are adjusted for actual length leftover length
                adjusted_naive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)]
                weighted_mean = ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs)
                # cached values were stored normalized, so they are re-scaled by counts before blending with the un-normalized sums
                conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_naive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs])
        del cached_naive_conds

    if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
        # already normalized, so return as is
        del counts_final
        return conds_final
    else:
        # normalize conds via division by context usage counts
        for i in range(len(conds_final)):
            conds_final[i] /= counts_final[i]
        del counts_final
        return conds_final


def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseShape):
    """Return copies of ``conds`` whose ``"model_conds"`` include ``c_concat``.

    Each entry of ``conds`` is either None (passed through unchanged) or a list
    of cond dicts. Every cond dict is shallow-copied; when it has a
    ``"model_conds"`` dict, that dict is copied too and its ``"c_concat"`` entry
    is set:

    * if a ``"c_concat"`` already exists, the existing and new ``.cond`` tensors
      are concatenated along dim 1 and wrapped in a new ``CONDNoiseShape``;
    * otherwise ``c_concat`` itself is stored.

    Cond dicts without ``"model_conds"`` are copied but otherwise left as-is.
    The input ``conds`` and their nested dicts are not modified.

    Args:
        conds: list of cond lists (or None entries).
        c_concat: the concat conditioning to add.

    Returns:
        A new list, parallel to ``conds``.
    """
    new_conds = []
    for cond in conds:
        if cond is None:
            new_conds.append(None)
            continue
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        resized_cond = []
        for actual_cond in cond:
            resized_actual_cond = actual_cond.copy()
            if "model_conds" in actual_cond:
                new_model_conds = actual_cond["model_conds"].copy()
                if "c_concat" in new_model_conds:
                    # torch.cat takes a sequence of tensors as its first argument
                    merged = torch.cat([new_model_conds["c_concat"].cond, c_concat.cond], dim=1)
                    new_model_conds["c_concat"] = comfy.conds.CONDNoiseShape(merged)
                else:
                    new_model_conds["c_concat"] = c_concat
                resized_actual_cond["model_conds"] = new_model_conds
            resized_cond.append(resized_actual_cond)
        new_conds.append(resized_cond)
    return new_conds