from .sd_motion import TemporalBlock
import torch


class SDXLMotionModel(torch.nn.Module):
    """Parameter container for the temporal (motion) blocks used with an SDXL UNet.

    This module only owns weights: ``forward`` does nothing, so individual
    blocks in ``motion_modules`` must be invoked by some external caller.

    Attributes:
        motion_modules: ``torch.nn.ModuleList`` of 15 ``TemporalBlock``
            instances with channel widths
            320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280,
            640, 640, 640, 320, 320, 320 (in that order).
        call_block_id: dict mapping an external block index to a position in
            ``motion_modules`` (values run 0..14 in key order).
            NOTE(review): the keys are presumably indices of UNet blocks after
            which the matching motion module is applied — confirm against the
            UNet code that reads this table.
    """

    def __init__(self):
        super().__init__()
        # Channel width of every motion module, in registration order.
        channel_widths = (
            320, 320,
            640, 640,
            1280, 1280,
            1280, 1280, 1280,
            640, 640, 640,
            320, 320, 320,
        )
        # Each block gets 8 groups of width // 8 channels.
        # NOTE(review): assumes TemporalBlock's positional args are
        # (heads, head_dim, dim) — confirm in sd_motion.
        self.motion_modules = torch.nn.ModuleList(
            [TemporalBlock(8, width // 8, width, eps=1e-6) for width in channel_widths]
        )
        # The i-th listed external block index is paired with motion module i.
        external_block_ids = (0, 2, 7, 10, 15, 18, 25, 28, 31, 35, 38, 41, 44, 46, 48)
        self.call_block_id = {
            block_id: module_index
            for module_index, block_id in enumerate(external_block_ids)
        }

    def forward(self):
        """No-op placeholder; the blocks are driven from outside this module."""
        return None

    @staticmethod
    def state_dict_converter():
        """Return the converter that renames external checkpoints to this layout."""
        return SDMotionModelStateDictConverter()


class SDMotionModelStateDictConverter:
    """Translate motion-module checkpoint keys into ``SDXLMotionModel`` keys."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename a motion-module state dict to the ``motion_modules.<i>...`` scheme.

        Only keys beginning with ``down_blocks.``, ``mid_block.`` or
        ``up_blocks.`` are kept; every other key is silently dropped. Kept keys
        are visited group by group (down, then mid, then up), each group sorted
        lexicographically. A new module index is assigned every time the key
        prefix up to and including ``temporal_transformer`` changes.

        Args:
            state_dict: mapping from checkpoint key (str) to tensor.

        Returns:
            A new dict keyed by ``motion_modules.<index>.<renamed path>``. For
            ``pos_encoder`` entries the trailing key component is dropped, so
            the value maps straight onto ``pe1`` / ``pe2``.

        Raises:
            ValueError: a kept key has no ``temporal_transformer`` component.
            KeyError: the path between ``temporal_transformer`` and the final
                component is not in the rename table.
        """
        # Checkpoint sub-path (after "temporal_transformer") -> local sub-path.
        local_name_for = {
            "norm": "norm",
            "proj_in": "proj_in",
            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
            "proj_out": "proj_out",
        }

        # Visit keys stage by stage so module indices follow down -> mid -> up.
        ordered_keys = []
        for stage_prefix in ("down_blocks.", "mid_block.", "up_blocks."):
            ordered_keys.extend(
                sorted(key for key in state_dict if key.startswith(stage_prefix))
            )

        converted = {}
        previous_prefix = ""
        module_index = -1
        for key in ordered_keys:
            parts = key.split(".")
            split_at = parts.index("temporal_transformer") + 1
            module_prefix = ".".join(parts[:split_at])
            # A change of prefix means the keys now belong to the next module.
            if module_prefix != previous_prefix:
                previous_prefix = module_prefix
                module_index += 1
            local_path = local_name_for[".".join(parts[split_at:-1])]
            if "pos_encoder" in parts:
                # Positional-encoding tensors map directly onto pe1 / pe2,
                # so the checkpoint's trailing component is not carried over.
                target_key = f"motion_modules.{module_index}.{local_path}"
            else:
                target_key = f"motion_modules.{module_index}.{local_path}.{parts[-1]}"
            converted[target_key] = state_dict[key]
        return converted

    def from_civitai(self, state_dict):
        """Civitai checkpoints share the same key layout; delegate to ``from_diffusers``."""
        return self.from_diffusers(state_dict)