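"""ComfyUI nodes for applying PIA motion models within AnimateDiff Gen2 workflows:
model application, experimental conv_in injection, PIA AD keyframes, and PIA input
(multival and paper-preset) nodes."""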
from typing import Union
import torch
from torch import Tensor
import math

from comfy.sd import VAE

from .ad_settings import AnimateDiffSettings
from .logger import logger
from .utils_model import BIGMIN, BIGMAX, get_available_motion_models
from .utils_motion import ADKeyframeGroup, InputPIA, InputPIA_Multival, extend_list_to_batch_size, extend_to_batch_size, prepare_mask_batch
from .motion_lora import MotionLoraList
from .model_injection import MotionModelGroup, MotionModelPatcher, get_mm_attachment, load_motion_module_gen2, inject_pia_conv_in_into_model
from .motion_module_ad import AnimateDiffFormat
from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode


# Preset values ported over from PIA repository:
# https://github.com/open-mmlab/PIA/blob/main/animatediff/utils/util.py
class PIA_RANGES:
    ANIMATION_SMALL = "Animation (Small Motion)"
    ANIMATION_MEDIUM = "Animation (Medium Motion)"
    ANIMATION_LARGE = "Animation (Large Motion)"
    LOOP_SMALL = "Loop (Small Motion)"
    LOOP_MEDIUM = "Loop (Medium Motion)"
    LOOP_LARGE = "Loop (Large Motion)"
    STYLE_TRANSFER_SMALL = "Style Transfer (Small Motion)"
    STYLE_TRANSFER_MEDIUM = "Style Transfer (Medium Motion)"
    STYLE_TRANSFER_LARGE = "Style Transfer (Large Motion)"

    _LOOPED = [LOOP_SMALL, LOOP_MEDIUM, LOOP_LARGE]
    _LIST_ALL = [ANIMATION_SMALL, ANIMATION_MEDIUM, ANIMATION_LARGE,
                 LOOP_SMALL, LOOP_MEDIUM, LOOP_LARGE,
                 STYLE_TRANSFER_SMALL, STYLE_TRANSFER_MEDIUM, STYLE_TRANSFER_LARGE]

    _MAPPING = {
        ANIMATION_SMALL: [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],
        ANIMATION_MEDIUM: [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],
        ANIMATION_LARGE: [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],
        LOOP_SMALL: [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],
        LOOP_MEDIUM: [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],
        LOOP_LARGE: [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5],
        STYLE_TRANSFER_SMALL: [0.5, 0.4, 0.4, 0.4, 0.35, 0.3],
        STYLE_TRANSFER_MEDIUM: [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2],
        STYLE_TRANSFER_LARGE: [0.5, 0.2],
    }

    @classmethod
    def get_preset(cls, preset: str) -> list[float]:
        if preset in cls._MAPPING:
            return cls._MAPPING[preset]
        raise Exception(f"PIA Preset '{preset}' is not recognized.")
    
    @classmethod
    def is_looped(cls, preset: str) -> bool:
        return preset in cls._LOOPED
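    # Illustrative usage (values taken directly from _MAPPING above):
    #   PIA_RANGES.get_preset(PIA_RANGES.STYLE_TRANSFER_LARGE)  -> [0.5, 0.2]
    #   PIA_RANGES.is_looped(PIA_RANGES.LOOP_SMALL)             -> True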


class InputPIA_PaperPresets(InputPIA):
    def __init__(self, preset: str, index: int, mult_multival: Union[float, Tensor]=None, effect_multival: Union[float, Tensor]=None):
        super().__init__(effect_multival=effect_multival)
        self.preset = preset
        self.index = index
        self.mult_multival = mult_multival if mult_multival is not None else 1.0
    
    def get_mask(self, x: Tensor):
        b, c, h, w = x.shape
        values = PIA_RANGES.get_preset(self.preset)
        # if preset is looped, make values loop
        if PIA_RANGES.is_looped(self.preset):
            # even length
            if b % 2 == 0:
                # extend to half length to get half of the loop
                values = extend_list_to_batch_size(values, b // 2)
                # apply second half of loop (just reverse it)
                values += list(reversed(values))
            # odd length
            else:
                inter_values = extend_list_to_batch_size(values, b // 2)
                middle_vals = [values[min(len(inter_values), len(values)-1)]]
                # make middle vals long enough to fill in gaps (or none if not needed)
                middle_vals = middle_vals * (max(0, b-2*len(inter_values)))
                values = inter_values + middle_vals + list(reversed(inter_values))
        # otherwise, just extend values to desired length
        else:
            values = extend_list_to_batch_size(values, b)
        assert len(values) == b

        index = self.index
        # handle negative index
        if index < 0:
            index = b + index
        # constrain index between 0 and b-1
        index = max(0, min(b-1, index))
        # center values around target index
        order = [abs(i - index) for i in range(b)]
        real_values = [values[order[i]] for i in range(b)]
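        # e.g. with b=5 and index=2: order = [2, 1, 0, 1, 2], so values[0] lands on the
        # target frame and successive preset values fan outward toward the batch edges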
        # using real values, generate masks
        tensor_values = torch.tensor(real_values).unsqueeze(-1).unsqueeze(-1)
        mask = torch.ones(size=(b, h, w)) * tensor_values
        # apply mult_multival to mask
        if type(self.mult_multival) == Tensor or not math.isclose(self.mult_multival, 1.0):
            real_mult = self.mult_multival
            if type(real_mult) == Tensor:
                real_mult = extend_to_batch_size(prepare_mask_batch(real_mult, x.shape), b).squeeze(1)
            mask = mask * real_mult
        return mask


class ApplyAnimateDiffPIAModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "motion_model": ("MOTION_MODEL_ADE",),
                "image": ("IMAGE",),
                "vae": ("VAE",),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
            "optional": {
                "pia_input": ("PIA_INPUT",),
                "motion_lora": ("MOTION_LORA",),
                "scale_multival": ("MULTIVAL",),
                "effect_multival": ("MULTIVAL",),
                "ad_keyframes": ("AD_KEYFRAMES",),
                "prev_m_models": ("M_MODELS",),
                "per_block": ("PER_BLOCK",),
                "autosize": ("ADEAUTOSIZE", {"padding": 0}),
            }
        }

    RETURN_TYPES = ("M_MODELS",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA"
    FUNCTION = "apply_motion_model"

    def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, vae: VAE,
                           start_percent: float=0.0, end_percent: float=1.0, pia_input: InputPIA=None,
                           motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None,
                           scale_multival=None, effect_multival=None, ref_multival=None, per_block=None,
                           prev_m_models: MotionModelGroup=None,):
        new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent,
                                                                    motion_lora=motion_lora, ad_keyframes=ad_keyframes,
                                                                    scale_multival=scale_multival, effect_multival=effect_multival, per_block=per_block,
                                                                    prev_m_models=prev_m_models)
        # most recently added model will always be first in the list
        curr_model = new_m_models[0].models[0]
        # confirm that model is PIA
        if curr_model.model.mm_info.mm_format != AnimateDiffFormat.PIA:
            raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' is not a PIA model; cannot be used with Apply AnimateDiff-PIA Model node.")
        attachment = get_mm_attachment(curr_model)
        attachment.orig_pia_images = image
        attachment.pia_vae = vae
        if pia_input is None:
            pia_input = InputPIA_Multival(1.0)
        attachment.pia_input = pia_input
        #curr_model.pia_multival = ref_multival
        return new_m_models


class LoadAnimateDiffAndInjectPIANode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (get_available_motion_models(),),
                "motion_model": ("MOTION_MODEL_ADE",),
            },
            "optional": {
                "ad_settings": ("AD_SETTINGS",),
                "deprecation_warning": ("ADEWARN", {"text": "Experimental. Don't expect to work.", "warn_type": "experimental", "color": "#CFC"}),
            }
        }
    
    RETURN_TYPES = ("MOTION_MODEL_ADE",)
    RETURN_NAMES = ("MOTION_MODEL",)

    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA/πŸ§ͺexperimental"
    FUNCTION = "load_motion_model"
    
    def load_motion_model(self, model_name: str, motion_model: MotionModelPatcher, ad_settings: AnimateDiffSettings=None):
        # make sure model actually has PIA conv_in
        if motion_model.model.conv_in is None:
            raise Exception("Passed-in motion model was expected to be PIA (contain conv_in), but did not.")
        # load motion module and motion settings, if included
        loaded_motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings)
        inject_pia_conv_in_into_model(motion_model=loaded_motion_model, w_pia=motion_model)
        return (loaded_motion_model,)


class PIA_ADKeyframeNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
            },
            "optional": {
                "prev_ad_keyframes": ("AD_KEYFRAMES", ),
                "scale_multival": ("MULTIVAL",),
                "effect_multival": ("MULTIVAL",),
                "pia_input": ("PIA_INPUT",),
                "inherit_missing": ("BOOLEAN", {"default": True}, ),
                "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}),
                "autosize": ("ADEAUTOSIZE", {"padding": 0}),
            }
        }
    
    RETURN_TYPES = ("AD_KEYFRAMES", )
    FUNCTION = "load_keyframe"

    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA"

    def load_keyframe(self,
                      start_percent: float, prev_ad_keyframes=None,
                      scale_multival: Union[float, torch.Tensor]=None, effect_multival: Union[float, torch.Tensor]=None,
                      pia_input: InputPIA=None,
                      inherit_missing: bool=True, guarantee_steps: int=1):
        return ADKeyframeNode.load_keyframe(self,
                    start_percent=start_percent, prev_ad_keyframes=prev_ad_keyframes,
                    scale_multival=scale_multival, effect_multival=effect_multival, pia_input=pia_input,
                    inherit_missing=inherit_missing, guarantee_steps=guarantee_steps
                )


class InputPIA_MultivalNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "multival": ("MULTIVAL",),
            },
            # "optional": {
            #     "effect_multival": ("MULTIVAL",),
            # }
        }
    
    RETURN_TYPES = ("PIA_INPUT",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA"
    FUNCTION = "create_pia_input"

    def create_pia_input(self, multival: Union[float, Tensor], effect_multival: Union[float, Tensor]=None):
        return (InputPIA_Multival(multival, effect_multival),)


class InputPIA_PaperPresetsNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "preset": (PIA_RANGES._LIST_ALL,),
                "batch_index": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
            },
            "optional": {
                "mult_multival": ("MULTIVAL",),
                "print_values": ("BOOLEAN", {"default": False},),
                "autosize": ("ADEAUTOSIZE", {"padding": 0}),
                #"effect_multival": ("MULTIVAL",),
            }
        }

    RETURN_TYPES = ("PIA_INPUT",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/PIA"
    FUNCTION = "create_pia_input"

    def create_pia_input(self, preset: str, batch_index: int, mult_multival: Union[float, Tensor]=None, print_values: bool=False, effect_multival: Union[float, Tensor]=None):
        # verify preset exists - get_preset will raise an exception if it does not
        values = PIA_RANGES.get_preset(preset)
        if print_values:
            logger.info(f"PIA Preset '{preset}': {values}")
        return (InputPIA_PaperPresets(preset=preset, index=batch_index, mult_multival=mult_multival, effect_multival=effect_multival),)