import torch

from collections import OrderedDict

from comfy import model_base
from comfy import utils
from comfy import diffusers_convert

try:
    # Newer ComfyUI versions ship sd2_clip under comfy.text_encoders; bind it to
    # the same name the fallback uses so clip_target() works on either layout.
    import comfy.text_encoders.sd2_clip as sd2_clip
except ImportError:
    from comfy import sd2_clip

from comfy import supported_models_base
from comfy import latent_formats

from ..lvdm.modules.encoders.resampler import Resampler

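# UNet configuration used to instantiate the DynamiCrafter video diffusion model.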
DYNAMICRAFTER_CONFIG = {
    'in_channels': 8,
    'out_channels': 4,
    'model_channels': 320,
    'attention_resolutions': [4, 2, 1],
    'num_res_blocks': 2,
    'channel_mult': [1, 2, 4, 4],
    'num_head_channels': 64,
    'transformer_depth': 1,
    'context_dim': 1024,
    'use_linear': True,
    'use_checkpoint': False,
    'temporal_conv': True,
    'temporal_attention': True,
    'temporal_selfatt_only': True,
    'use_relative_position': False,
    'use_causal_attention': False,
    'temporal_length': 16,
    'addition_attention': True,
    'image_cross_attention': True,
    'image_cross_attention_scale_learnable': True,
    'default_fs': 3,
    'fs_condition': True
}

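# Configuration for the Resampler image-projection model: 1280-dim image
# embeddings in (embedding_dim), 1024-dim cross-attention context out (output_dim).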
IMAGE_PROJ_CONFIG = {
    "dim": 1024,
    "depth": 4,
    "dim_head": 64,
    "heads": 12,
    "num_queries": 16,
    "embedding_dim": 1280,
    "output_dim": 1024,
    "ff_mult": 4,
    "video_length": 16
}

def process_list_or_str(target_key_or_keys, k):
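    """Return True if `k` contains `target_key_or_keys` (a substring, or any item of a list of substrings)."""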
    if isinstance(target_key_or_keys, list):
        return any(list_k in k for list_k in target_key_or_keys)
    else:
        return target_key_or_keys in k

def simple_state_dict_loader(state_dict: dict, target_key: str, target_dict: dict = None):
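    """Collect the entries of `state_dict` whose keys match `target_key`.

    If `target_dict` is provided, its keys are used to pick entries from
    `state_dict` directly instead of substring matching.
    """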
    out_dict = {}
    
    if target_dict is None:
        for k, v in state_dict.items():
            if process_list_or_str(target_key, k):
                out_dict[k] = v
    else:
        for k in target_dict:
            out_dict[k] = state_dict[k]
            
    return out_dict

def load_image_proj_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'image_proj')

def load_dynamicrafter_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'model.diffusion_model')

def load_vae_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'first_stage_model')

def get_base_model(state_dict: dict, version_checker=False):
    """Check which DynamiCrafter variant a checkpoint contains.

    Checkpoints that still carry the legacy 'framestride_embed' keys are the
    256-resolution models; the flag is returned so callers can branch on it.
    """
    is_256_model = False

    for k in state_dict.keys():
        if "framestride_embed" in k:
            is_256_model = True
            break

    return is_256_model

def get_image_proj_model(state_dict: dict):
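    """Instantiate the Resampler image-projection model and load its weights.

    The 'image_proj_model.' prefix is stripped from the keys so they match the
    Resampler's own state_dict layout.
    """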

    state_dict = {k.replace('image_proj_model.', ''): v for k, v in state_dict.items()}
    #target_dict = Resampler().state_dict()

    ImageProjModel = Resampler(**IMAGE_PROJ_CONFIG)
    ImageProjModel.load_state_dict(state_dict)

    print("Image Projection Model loaded successfully")
    #del target_dict
    return ImageProjModel

class DynamiCrafterBase(supported_models_base.BASE):
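    """ComfyUI supported-model definition for DynamiCrafter (SD 2.x text encoder, SD1.5 latent format)."""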
    unet_config = {}
    unet_extra_config = {}

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

    def process_dict_version(self, state_dict: dict):
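        """Rename legacy 'framestride_embed' keys to 'fps_embedding'; `is_eps` reports whether any were found."""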
        processed_dict = OrderedDict()
        is_eps = False
        
        for k in list(state_dict.keys()):
            if "framestride_embed" in k:
                new_key = k.replace("framestride_embed", "fps_embedding")
                processed_dict[new_key] = state_dict[k]
                is_eps = True
                continue
            
            processed_dict[k] = state_dict[k]

        return processed_dict, is_eps
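
# Example (illustrative, not part of the original module): splitting a combined
# DynamiCrafter checkpoint into its component state dicts with the helpers above.
# The checkpoint path is a placeholder.
#
#   sd = utils.load_torch_file("path/to/dynamicrafter.ckpt")
#   unet_sd = load_dynamicrafter_dict(sd)      # keys containing 'model.diffusion_model'
#   vae_sd = load_vae_dict(sd)                 # keys containing 'first_stage_model'
#   image_proj_sd = load_image_proj_dict(sd)   # keys containing 'image_proj'
#   image_proj_model = get_image_proj_model(image_proj_sd)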