import torch

from collections import OrderedDict

from comfy import model_base
from comfy import utils
from comfy import diffusers_convert

# sd2_clip moved into comfy.text_encoders in newer ComfyUI versions;
# fall back to the old location for older installs. Binding the module
# as `sd2_clip` in both branches keeps the later references working.
try:
    from comfy.text_encoders import sd2_clip
except ImportError:
    from comfy import sd2_clip

from comfy import supported_models_base
from comfy import latent_formats

from ..lvdm.modules.encoders.resampler import Resampler
|
# UNet configuration for the DynamiCrafter diffusion model. in_channels is 8
# because the image-condition latents (4 channels) are concatenated with the
# noised video latents (4 channels).
DYNAMICRAFTER_CONFIG = {
    'in_channels': 8,
    'out_channels': 4,
    'model_channels': 320,
    'attention_resolutions': [4, 2, 1],
    'num_res_blocks': 2,
    'channel_mult': [1, 2, 4, 4],
    'num_head_channels': 64,
    'transformer_depth': 1,
    'context_dim': 1024,
    'use_linear': True,
    'use_checkpoint': False,
    'temporal_conv': True,
    'temporal_attention': True,
    'temporal_selfatt_only': True,
    'use_relative_position': False,
    'use_causal_attention': False,
    'temporal_length': 16,
    'addition_attention': True,
    'image_cross_attention': True,
    'image_cross_attention_scale_learnable': True,
    'default_fs': 3,
    'fs_condition': True
}
|
# Configuration for the Resampler that projects CLIP image embeddings
# (embedding_dim 1280) into the UNet's cross-attention space (output_dim
# 1024, matching context_dim above).
IMAGE_PROJ_CONFIG = {
    "dim": 1024,
    "depth": 4,
    "dim_head": 64,
    "heads": 12,
    "num_queries": 16,
    "embedding_dim": 1280,
    "output_dim": 1024,
    "ff_mult": 4,
    "video_length": 16
}
|
def process_list_or_str(target_key_or_keys, k):
    # True if key `k` contains the target substring, or any of them
    # when a list of substrings is given.
    if isinstance(target_key_or_keys, list):
        return any(list_k in k for list_k in target_key_or_keys)
    return target_key_or_keys in k
|
def simple_state_dict_loader(state_dict: dict, target_key: str, target_dict: dict = None):
    # Collect the subset of `state_dict` whose keys match `target_key`;
    # if `target_dict` is given, gather the values for its keys instead.
    out_dict = {}

    if target_dict is None:
        for k, v in state_dict.items():
            if process_list_or_str(target_key, k):
                out_dict[k] = v
    else:
        for k in target_dict:
            out_dict[k] = state_dict[k]

    return out_dict
|
def load_image_proj_dict(state_dict: dict):
    # Image projection (Resampler) weights.
    return simple_state_dict_loader(state_dict, 'image_proj')

def load_dynamicrafter_dict(state_dict: dict):
    # Diffusion UNet weights.
    return simple_state_dict_loader(state_dict, 'model.diffusion_model')

def load_vae_dict(state_dict: dict):
    # First-stage (VAE) weights.
    return simple_state_dict_loader(state_dict, 'first_stage_model')
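
# Usage sketch (hypothetical; the checkpoint filename is an assumption): a full
# DynamiCrafter checkpoint bundles the UNet, VAE, and image projection weights,
# and the helpers above split it into per-module state dicts:
#
#   sd = utils.load_torch_file("dynamicrafter_512.ckpt")
#   image_proj_sd = load_image_proj_dict(sd)   # Resampler weights
#   unet_sd = load_dynamicrafter_dict(sd)      # diffusion UNet weights
#   vae_sd = load_vae_dict(sd)                 # first-stage VAE weights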
|
def get_base_model(state_dict: dict, version_checker=False):
    # "framestride_embed" only appears in the 256-resolution checkpoints;
    # process_dict_version below renames it to "fps_embedding".
    is_256_model = False

    for k in state_dict.keys():
        if "framestride_embed" in k:
            is_256_model = True
            break

    return is_256_model
|
def get_image_proj_model(state_dict: dict):
    # Strip the checkpoint prefix so keys match the Resampler's parameter names.
    state_dict = {k.replace('image_proj_model.', ''): v for k, v in state_dict.items()}

    image_proj_model = Resampler(**IMAGE_PROJ_CONFIG)
    image_proj_model.load_state_dict(state_dict)

    print("Image Projection Model loaded successfully")

    return image_proj_model
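
# For example (hypothetical variable name `sd` for a loaded checkpoint), the
# projection model can be built straight from the filtered sub-dict:
#
#   proj_model = get_image_proj_model(load_image_proj_dict(sd))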
|
class DynamiCrafterBase(supported_models_base.BASE):
    unet_config = {}
    unet_extra_config = {}

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        # Normalize both text encoder prefixes onto "clip_h.", which
        # ComfyUI's SD 2.x (OpenCLIP ViT-H) CLIP model expects.
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h."
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        # Invert the mapping above when writing a checkpoint back out.
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
    def process_dict_version(self, state_dict: dict):
        # Rename legacy "framestride_embed" keys to "fps_embedding"; their
        # presence also marks the checkpoint as an eps-prediction model.
        processed_dict = OrderedDict()
        is_eps = False

        for k in list(state_dict.keys()):
            if "framestride_embed" in k:
                new_key = k.replace("framestride_embed", "fps_embedding")
                processed_dict[new_key] = state_dict[k]
                is_eps = True
                continue

            processed_dict[k] = state_dict[k]

        return processed_dict, is_eps
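
# Sketch of the rename performed by process_dict_version (hypothetical key):
#   "model.diffusion_model.framestride_embed.0.weight"
# becomes
#   "model.diffusion_model.fps_embedding.0.weight"
# and is_eps is returned as True for such checkpoints.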
|