import torch
from collections import OrderedDict

from comfy import model_base
from comfy import utils
from comfy import diffusers_convert

try:
    # Newer ComfyUI versions moved sd2_clip under comfy.text_encoders;
    # alias it so both layouts bind the same `sd2_clip` name.
    import comfy.text_encoders.sd2_clip as sd2_clip
except ImportError:
    from comfy import sd2_clip

from comfy import supported_models_base
from comfy import latent_formats

from ..lvdm.modules.encoders.resampler import Resampler

DYNAMICRAFTER_CONFIG = {
    'in_channels': 8,
    'out_channels': 4,
    'model_channels': 320,
    'attention_resolutions': [4, 2, 1],
    'num_res_blocks': 2,
    'channel_mult': [1, 2, 4, 4],
    'num_head_channels': 64,
    'transformer_depth': 1,
    'context_dim': 1024,
    'use_linear': True,
    'use_checkpoint': False,
    'temporal_conv': True,
    'temporal_attention': True,
    'temporal_selfatt_only': True,
    'use_relative_position': False,
    'use_causal_attention': False,
    'temporal_length': 16,
    'addition_attention': True,
    'image_cross_attention': True,
    'image_cross_attention_scale_learnable': True,
    'default_fs': 3,
    'fs_condition': True
}

IMAGE_PROJ_CONFIG = {
    "dim": 1024,
    "depth": 4,
    "dim_head": 64,
    "heads": 12,
    "num_queries": 16,
    "embedding_dim": 1280,
    "output_dim": 1024,
    "ff_mult": 4,
    "video_length": 16
}


def process_list_or_str(target_key_or_keys, k):
    # Accepts either a single substring or a list of substrings and reports
    # whether any of them occur in the key name k.
    if isinstance(target_key_or_keys, list):
        return any(list_k in k for list_k in target_key_or_keys)
    return target_key_or_keys in k


def simple_state_dict_loader(state_dict: dict, target_key, target_dict: dict = None):
    out_dict = {}
    if target_dict is None:
        # Filter state_dict down to the keys matching target_key.
        for k, v in state_dict.items():
            if process_list_or_str(target_key, k):
                out_dict[k] = v
    else:
        # Gather values from state_dict using target_dict's keys; assumes
        # every key of target_dict is present in state_dict.
        for k in target_dict:
            out_dict[k] = state_dict[k]
    return out_dict


def load_image_proj_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'image_proj')


def load_dynamicrafter_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'model.diffusion_model')


def load_vae_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'first_stage_model')


def get_base_model(state_dict: dict, version_checker=False):
    # The 256-resolution checkpoints are identified by the presence of a
    # "framestride_embed" key in the state dict. `version_checker` is kept
    # for API compatibility but is currently unused.
    is_256_model = False
    for k in state_dict.keys():
        if "framestride_embed" in k:
            is_256_model = True
            break
    return is_256_model


def get_image_proj_model(state_dict: dict):
    # Strip the "image_proj_model." prefix so the keys line up with a bare
    # Resampler's state dict, then load the weights.
    state_dict = {k.replace('image_proj_model.', ''): v for k, v in state_dict.items()}
    image_proj_model = Resampler(**IMAGE_PROJ_CONFIG)
    image_proj_model.load_state_dict(state_dict)
    print("Image Projection Model loaded successfully")
    return image_proj_model


class DynamiCrafterBase(supported_models_base.BASE):
    unet_config = {}
    unet_extra_config = {}

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h."  # SD2 in sgm format
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

    def process_dict_version(self, state_dict: dict):
        processed_dict = OrderedDict()
        is_eps = False

        for k in list(state_dict.keys()):
            if "framestride_embed" in k:
                # Older (eps-prediction) checkpoints name the fps embedding
                # "framestride_embed"; rename it to the expected "fps_embedding".
                new_key = k.replace("framestride_embed", "fps_embedding")
                processed_dict[new_key] = state_dict[k]
                is_eps = True
                continue
            processed_dict[k] = state_dict[k]

        return processed_dict, is_eps
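

# ---------------------------------------------------------------------------
# Usage sketch: splitting a combined DynamiCrafter checkpoint with the loaders
# above. This helper is illustrative and not part of the original module; its
# name and the example checkpoint filename are assumptions added for clarity.
# ---------------------------------------------------------------------------
def split_dynamicrafter_checkpoint(checkpoint_path: str) -> dict:
    # comfy.utils.load_torch_file handles both .ckpt and .safetensors files.
    state_dict = utils.load_torch_file(checkpoint_path)
    return {
        "unet": load_dynamicrafter_dict(state_dict),     # "model.diffusion_model.*" keys
        "vae": load_vae_dict(state_dict),                # "first_stage_model.*" keys
        "image_proj": load_image_proj_dict(state_dict),  # "image_proj_model.*" keys
    }

# Example (hypothetical checkpoint file):
#   parts = split_dynamicrafter_checkpoint("dynamicrafter_512.ckpt")
#   image_proj_model = get_image_proj_model(parts["image_proj"])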