import torch
from collections import OrderedDict
from comfy import model_base
from comfy import utils
from comfy import diffusers_convert
# Newer ComfyUI versions moved sd2_clip under comfy.text_encoders; bind it
# to the same name either way so clip_target() can reference it directly.
try:
    import comfy.text_encoders.sd2_clip as sd2_clip
except ImportError:
    from comfy import sd2_clip
from comfy import supported_models_base
from comfy import latent_formats
from ..lvdm.modules.encoders.resampler import Resampler
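# UNet configuration for the DynamiCrafter video model. The 8 input
# channels come from concatenating the 4-channel noise latent with the
# 4-channel conditioning-image latent; the output stays 4 channels.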
DYNAMICRAFTER_CONFIG = {
'in_channels': 8,
'out_channels': 4,
'model_channels': 320,
'attention_resolutions': [4, 2, 1],
'num_res_blocks': 2,
'channel_mult': [1, 2, 4, 4],
'num_head_channels': 64,
'transformer_depth': 1,
'context_dim': 1024,
'use_linear': True,
'use_checkpoint': False,
'temporal_conv': True,
'temporal_attention': True,
'temporal_selfatt_only': True,
'use_relative_position': False,
'use_causal_attention': False,
'temporal_length': 16,
'addition_attention': True,
'image_cross_attention': True,
'image_cross_attention_scale_learnable': True,
'default_fs': 3,
'fs_condition': True
}
IMAGE_PROJ_CONFIG = {
"dim": 1024,
"depth": 4,
"dim_head": 64,
"heads": 12,
"num_queries": 16,
"embedding_dim": 1280,
"output_dim": 1024,
"ff_mult": 4,
"video_length": 16
}
def process_list_or_str(target_key_or_keys, k):
    """Return True if key `k` contains the target substring (or any one of a list)."""
    if isinstance(target_key_or_keys, list):
        return any(list_k in k for list_k in target_key_or_keys)
    return target_key_or_keys in k
def simple_state_dict_loader(state_dict: dict, target_key, target_dict: dict = None):
    """Collect the entries of `state_dict` whose keys match `target_key`.

    `target_key` may be a string or a list of strings. If `target_dict` is
    given, its keys are used to index `state_dict` directly instead.
    """
    out_dict = {}
    if target_dict is None:
        for k, v in state_dict.items():
            if process_list_or_str(target_key, k):
                out_dict[k] = v
    else:
        for k in target_dict.keys():
            out_dict[k] = state_dict[k]
    return out_dict
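# Convenience filters that split a combined DynamiCrafter checkpoint into
# its sub-networks by key prefix.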
def load_image_proj_dict(state_dict: dict):
return simple_state_dict_loader(state_dict, 'image_proj')
def load_dynamicrafter_dict(state_dict: dict):
return simple_state_dict_loader(state_dict, 'model.diffusion_model')
def load_vae_dict(state_dict: dict):
return simple_state_dict_loader(state_dict, 'first_stage_model')
def get_base_model(state_dict: dict, version_checker=False):
    """Detect whether the checkpoint is the 256-resolution (eps) variant.

    Those checkpoints store their frame-stride conditioning under a
    'framestride_embed' key (see process_dict_version below).
    """
    # NOTE: version_checker is currently unused.
    is_256_model = any("framestride_embed" in k for k in state_dict.keys())
    return is_256_model
def get_image_proj_model(state_dict: dict):
    """Instantiate the Resampler and load the image projection weights."""
    state_dict = {k.replace('image_proj_model.', ''): v for k, v in state_dict.items()}
    image_proj_model = Resampler(**IMAGE_PROJ_CONFIG)
    image_proj_model.load_state_dict(state_dict)
    print("Image Projection Model loaded successfully")
    return image_proj_model
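# ComfyUI supported-model definition: DynamiCrafter pairs the SD2.x
# (open_clip ViT-H) text encoder with the SD1.5 latent format.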
class DynamiCrafterBase(supported_models_base.BASE):
unet_config = {}
unet_extra_config = {}
latent_format = latent_formats.SD15
def process_clip_state_dict(self, state_dict):
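        """Remap text-encoder keys (sgm or legacy layout) onto ComfyUI's clip_h prefix."""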
replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
replace_prefix["cond_stage_model.model."] = "clip_h."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
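        """Map clip_h keys back to the original cond_stage_model layout for saving."""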
replace_prefix = {}
replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
def clip_target(self):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
def process_dict_version(self, state_dict: dict):
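        """Rename legacy 'framestride_embed' keys to 'fps_embedding'.

        Returns the processed state dict and a flag that is True for
        eps-prediction (256-resolution) checkpoints.
        """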
processed_dict = OrderedDict()
is_eps = False
for k in list(state_dict.keys()):
if "framestride_embed" in k:
new_key = k.replace("framestride_embed", "fps_embedding")
processed_dict[new_key] = state_dict[k]
is_eps = True
continue
processed_dict[k] = state_dict[k]
return processed_dict, is_eps
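# Example (sketch, illustrative names only): splitting a combined
# checkpoint into its parts with the helpers above.
#
#   state_dict = utils.load_torch_file(ckpt_path)  # ckpt_path is hypothetical
#   image_proj = get_image_proj_model(load_image_proj_dict(state_dict))
#   unet_sd = load_dynamicrafter_dict(state_dict)
#   vae_sd = load_vae_dict(state_dict)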