import torch
from collections import OrderedDict
from comfy import model_base
from comfy import utils
from comfy import diffusers_convert
try:
    # Bind the module to the name `sd2_clip` so both branches expose the same symbol.
    import comfy.text_encoders.sd2_clip as sd2_clip
except ImportError:
    from comfy import sd2_clip
from comfy import supported_models_base
from comfy import latent_formats
from ..lvdm.modules.encoders.resampler import Resampler
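# UNet configuration for the DynamiCrafter video diffusion model: an
# 8-channel input (the 4-channel noisy latent concatenated with the 4-channel
# conditioning-image latent), temporal attention/convolution over 16 frames,
# and learnable image cross-attention for the conditioning frame.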
DYNAMICRAFTER_CONFIG = {
    'in_channels': 8,
    'out_channels': 4,
    'model_channels': 320,
    'attention_resolutions': [4, 2, 1],
    'num_res_blocks': 2,
    'channel_mult': [1, 2, 4, 4],
    'num_head_channels': 64,
    'transformer_depth': 1,
    'context_dim': 1024,
    'use_linear': True,
    'use_checkpoint': False,
    'temporal_conv': True,
    'temporal_attention': True,
    'temporal_selfatt_only': True,
    'use_relative_position': False,
    'use_causal_attention': False,
    'temporal_length': 16,
    'addition_attention': True,
    'image_cross_attention': True,
    'image_cross_attention_scale_learnable': True,
    'default_fs': 3,
    'fs_condition': True
}
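# Configuration for the Resampler that projects CLIP vision features
# (1280-dim, ViT-H-sized) down to the UNet's 1024-dim cross-attention width,
# yielding num_queries query tokens per frame across video_length frames.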
IMAGE_PROJ_CONFIG = {
    "dim": 1024,
    "depth": 4,
    "dim_head": 64,
    "heads": 12,
    "num_queries": 16,
    "embedding_dim": 1280,
    "output_dim": 1024,
    "ff_mult": 4,
    "video_length": 16
}
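# Returns True if `k` contains the target substring (or any substring in a
# list of targets).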
def process_list_or_str(target_key_or_keys, k):
    if isinstance(target_key_or_keys, list):
        return any(list_k in k for list_k in target_key_or_keys)
    else:
        return target_key_or_keys in k
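# Filters `state_dict` down to entries whose keys match `target_key`; if a
# `target_dict` is supplied instead, its keys are used verbatim to index into
# `state_dict`.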
def simple_state_dict_loader(state_dict: dict, target_key: str, target_dict: dict = None):
    out_dict = {}
    if target_dict is None:
        for k, v in state_dict.items():
            if process_list_or_str(target_key, k):
                out_dict[k] = v
    else:
        for k, v in target_dict.items():
            out_dict[k] = state_dict[k]
    return out_dict
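# Convenience wrappers that split a combined DynamiCrafter checkpoint into its
# three sub-models: the image-projection Resampler, the diffusion UNet, and
# the VAE ("first_stage_model" in LDM naming).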
def load_image_proj_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'image_proj')

def load_dynamicrafter_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'model.diffusion_model')

def load_vae_dict(state_dict: dict):
    return simple_state_dict_loader(state_dict, 'first_stage_model')
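# Detects the checkpoint variant: checkpoints that still carry a
# "framestride_embed" key are the 256-resolution release. The `version_checker`
# flag is currently unused.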
def get_base_model(state_dict: dict, version_checker=False):
    is_256_model = False
    for k in state_dict.keys():
        if "framestride_embed" in k:
            is_256_model = True
            break
    # NOTE: assumed return; without it the result of the scan above would be
    # silently discarded.
    return is_256_model
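# Builds the Resampler from IMAGE_PROJ_CONFIG and loads the checkpoint's
# "image_proj_model." weights into it (the prefix is stripped first so the
# keys line up with the Resampler's parameter names).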
def get_image_proj_model(state_dict: dict):
    state_dict = {k.replace('image_proj_model.', ''): v for k, v in state_dict.items()}
    ImageProjModel = Resampler(**IMAGE_PROJ_CONFIG)
    ImageProjModel.load_state_dict(state_dict)
    print("Image Projection Model loaded successfully")
    return ImageProjModel
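# Plugs DynamiCrafter into ComfyUI's supported-models machinery. DynamiCrafter
# uses the SD 2.1 text encoder (OpenCLIP-H) and the SD 1.5 latent format, so
# the CLIP-handling hooks below follow ComfyUI's SD2 conventions.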
class DynamiCrafterBase(supported_models_base.BASE):
    unet_config = {}
    unet_extra_config = {}

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h."  # SD2 in sgm format
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict
    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
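    # Renames legacy "framestride_embed" keys to "fps_embedding" and reports
    # whether any were found; the presence of that key is treated here as the
    # marker of an eps-prediction checkpoint.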
    def process_dict_version(self, state_dict: dict):
        processed_dict = OrderedDict()
        is_eps = False

        for k in list(state_dict.keys()):
            if "framestride_embed" in k:
                new_key = k.replace("framestride_embed", "fps_embedding")
                processed_dict[new_key] = state_dict[k]
                is_eps = True
                continue
            processed_dict[k] = state_dict[k]

        return processed_dict, is_eps
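# Minimal usage sketch (assumptions: a local checkpoint path; file name is
# hypothetical). `utils.load_torch_file` is ComfyUI's checkpoint loader:
#
#     state_dict = utils.load_torch_file("dynamicrafter.ckpt")
#     image_proj = get_image_proj_model(load_image_proj_dict(state_dict))
#     unet_sd = load_dynamicrafter_dict(state_dict)
#     vae_sd = load_vae_dict(state_dict)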