import argparse
import numpy as np
import torch
import os
import yaml
import random
from diffusers.utils.import_utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from diffusers import EulerDiscreteScheduler
if is_accelerate_available():
    from accelerate import init_empty_weights
from contextlib import nullcontext


def seed_everything(seed):
    # np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


from typing import Dict, Union
from collections import defaultdict

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
# We need to set Attention Processors for the following keys.
all_processor_keys = [
    'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor',
    'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'
]


def load_migc(unet, attention_store, pretrained_MIGC_path: Union[str, Dict[str, torch.Tensor]], attn_processor,
              **kwargs):
    state_dict = torch.load(pretrained_MIGC_path, map_location="cpu")

    # fill attn processors
    attn_processors = {}
    state_dict = state_dict['state_dict']
    adapter_grouped_dict = defaultdict(dict)

    # Rewrite the MIGC checkpoint keys (LDM-style 'input_blocks' / 'middle_block' /
    # 'output_blocks') into the diffusers UNet naming scheme; see the mapping
    # example after this function.
    for key, value in state_dict.items():
        key_list = key.split(".")
        assert 'migc' in key_list
        if 'input_blocks' in key_list:
            model_type = 'down_blocks'
        elif 'middle_block' in key_list:
            model_type = 'mid_block'
        else:
            model_type = 'up_blocks'
        index_number = int(key_list[3])
        if model_type == 'down_blocks':
            input_num1 = str(index_number // 3)
            input_num2 = str((index_number % 3) - 1)
        elif model_type == 'mid_block':
            input_num1 = '0'
            input_num2 = '0'
        else:
            input_num1 = str(index_number // 3)
            input_num2 = str(index_number % 3)
        attn_key_list = [model_type, input_num1, 'attentions', input_num2, 'transformer_blocks', '0']
        if model_type == 'mid_block':
            attn_key_list = [model_type, 'attentions', input_num2, 'transformer_blocks', '0']
        attn_processor_key = '.'.join(attn_key_list)
        sub_key = '.'.join(key_list[key_list.index('migc'):])
        adapter_grouped_dict[attn_processor_key][sub_key] = value

    # Create MIGC Processor
    config = {'not_use_migc': False}
    for key, value_dict in adapter_grouped_dict.items():
        dim = value_dict['migc.norm.bias'].shape[0]
        config['C'] = dim
        key_final = key + '.attn2.processor'
        if key_final.startswith("mid_block"):
            place_in_unet = "mid"
        elif key_final.startswith("up_blocks"):
            place_in_unet = "up"
        elif key_final.startswith("down_blocks"):
            place_in_unet = "down"
        attn_processors[key_final] = attn_processor(config, attention_store, place_in_unet)
        attn_processors[key_final].load_state_dict(value_dict)
        attn_processors[key_final].to(device=unet.device, dtype=unet.dtype)

    # Create CrossAttention/SelfAttention Processor
    config = {'not_use_migc': True}
    for key in all_processor_keys:
        if key not in attn_processors.keys():
            if key.startswith("mid_block"):
                place_in_unet = "mid"
            elif key.startswith("up_blocks"):
                place_in_unet = "up"
            elif key.startswith("down_blocks"):
                place_in_unet = "down"
            attn_processors[key] = attn_processor(config, attention_store, place_in_unet)
    unet.set_attn_processor(attn_processors)
    attention_store.num_att_layers = 32
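
# Example of the remapping performed in load_migc above. The full checkpoint key
# prefix is illustrative; only the block-index arithmetic comes from the code:
#   ...input_blocks.7.*.attn2.migc.*   -> 'down_blocks.2.attentions.0.transformer_blocks.0'  (7 // 3 = 2, 7 % 3 - 1 = 0)
#   ...output_blocks.11.*.attn2.migc.* -> 'up_blocks.3.attentions.2.transformer_blocks.0'    (11 // 3 = 3, 11 % 3 = 2)
#   ...middle_block.*.attn2.migc.*     -> 'mid_block.attentions.0.transformer_blocks.0'
# Each remapped key then has '.attn2.processor' appended so it addresses the
# matching entry in all_processor_keys.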


def offlinePipelineSetupWithSafeTensor(sd_safetensors_path):
    project_dir = os.path.dirname(os.path.dirname(__file__))
    migc_ckpt_path = os.path.join(project_dir, 'pretrained_weights/MIGC_SD14.ckpt')
    clip_model_path = os.path.join(project_dir, 'migc_gui_weights/clip/text_encoder')
    clip_tokenizer_path = os.path.join(project_dir, 'migc_gui_weights/clip/tokenizer')
    original_config_file = os.path.join(project_dir, 'migc_gui_weights/v1-inference.yaml')
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        # text_encoder = CLIPTextModel(config)
        text_encoder = CLIPTextModel.from_pretrained(clip_model_path)
        tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path)
    pipe = StableDiffusionMIGCPipeline.from_single_file(sd_safetensors_path,
                                                        original_config_file=original_config_file,
                                                        text_encoder=text_encoder,
                                                        tokenizer=tokenizer,
                                                        load_safety_checker=False)
    print('Initializing pipeline')
    pipe.attention_store = AttentionStore()
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path, attn_processor=MIGCProcessor)
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe
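

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): build the MIGC
    # pipeline from a local SD 1.x .safetensors file and run one layout-to-image
    # generation. The checkpoint path, prompts, boxes, and seed are placeholders,
    # and the call signature (nested prompts, bboxes, MIGCsteps) follows the MIGC
    # project README; adjust it if your pipeline version differs.
    pipe = offlinePipelineSetupWithSafeTensor('path/to/sd_v1.safetensors')
    pipe = pipe.to("cuda")
    seed_everything(42)

    # One image: a global caption followed by per-instance captions, with one
    # [x0, y0, x1, y1] box (normalized to [0, 1]) per instance.
    prompt_final = [['a yellow table and a red vase',
                     'a yellow table', 'a red vase']]
    bboxes = [[[0.1, 0.5, 0.9, 0.9], [0.4, 0.1, 0.6, 0.5]]]
    image = pipe(prompt_final, bboxes, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25).images[0]
    image.save('migc_output.png')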