import argparse
import numpy as np
import torch
import os
import yaml
import random
from typing import Callable, Dict, List, Optional, Union
from collections import defaultdict
from diffusers.utils.import_utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from diffusers import EulerDiscreteScheduler

if is_accelerate_available():
    from accelerate import init_empty_weights
from contextlib import nullcontext


def seed_everything(seed):
    # Seed the Python and PyTorch RNGs so sampling is reproducible.
    # np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"

# We need to set Attention Processors for the following keys.
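# (The SD 1.x UNet has 16 attention modules, each with a self-attention `attn1`
# and a cross-attention `attn2` processor, giving the 32 entries below.)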
all_processor_keys = [
    'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor',
    'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'
]

def load_migc(unet, attention_store, pretrained_MIGC_path: Union[str, Dict[str, torch.Tensor]], attn_processor,
              **kwargs):
    state_dict = torch.load(pretrained_MIGC_path, map_location="cpu")
    # fill attn processors
    attn_processors = {}
    state_dict = state_dict['state_dict']
    adapter_grouped_dict = defaultdict(dict)
    # Remap the keys of MIGC.ckpt to the diffusers UNet naming scheme and group
    # the MIGC weights by the attention block they belong to.
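    # For example, assuming LDM-style keys of the form
    # 'model.diffusion_model.<input/middle/output>_blocks.<i>. ... .migc.*':
    #   input_blocks.4  -> down_blocks.1.attentions.0.transformer_blocks.0   (4 // 3 = 1, 4 % 3 - 1 = 0)
    #   output_blocks.3 -> up_blocks.1.attentions.0.transformer_blocks.0     (3 // 3 = 1, 3 % 3 = 0)
    #   middle_block    -> mid_block.attentions.0.transformer_blocks.0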
    for key, value in state_dict.items():
        key_list = key.split(".")
        assert 'migc' in key_list
        if 'input_blocks' in key_list:
            model_type = 'down_blocks'
        elif 'middle_block' in key_list:
            model_type = 'mid_block'
        else:
            model_type = 'up_blocks'
        index_number = int(key_list[3])
        if model_type == 'down_blocks':
            input_num1 = str(index_number // 3)
            input_num2 = str((index_number % 3) - 1)
        elif model_type == 'mid_block':
            input_num1 = '0'
            input_num2 = '0'
        else:
            input_num1 = str(index_number // 3)
            input_num2 = str(index_number % 3)
        attn_key_list = [model_type, input_num1, 'attentions', input_num2, 'transformer_blocks', '0']
        if model_type == 'mid_block':
            attn_key_list = [model_type, 'attentions', input_num2, 'transformer_blocks', '0']
        attn_processor_key = '.'.join(attn_key_list)
        sub_key = '.'.join(key_list[key_list.index('migc'):])
        adapter_grouped_dict[attn_processor_key][sub_key] = value
    # Create MIGC Processor
    config = {'not_use_migc': False}
    for key, value_dict in adapter_grouped_dict.items():
        dim = value_dict['migc.norm.bias'].shape[0]
        config['C'] = dim
        key_final = key + '.attn2.processor'
        if key_final.startswith("mid_block"):
            place_in_unet = "mid"
        elif key_final.startswith("up_blocks"):
            place_in_unet = "up"
        elif key_final.startswith("down_blocks"):
            place_in_unet = "down"
        attn_processors[key_final] = attn_processor(config, attention_store, place_in_unet)
        attn_processors[key_final].load_state_dict(value_dict)
        attn_processors[key_final].to(device=unet.device, dtype=unet.dtype)
    # For the remaining attention layers, fall back to standard
    # CrossAttention/SelfAttention Processors (MIGC disabled).
    config = {'not_use_migc': True}
    for key in all_processor_keys:
        if key not in attn_processors.keys():
            if key.startswith("mid_block"):
                place_in_unet = "mid"
            elif key.startswith("up_blocks"):
                place_in_unet = "up"
            elif key.startswith("down_blocks"):
                place_in_unet = "down"
            attn_processors[key] = attn_processor(config, attention_store, place_in_unet)
    unet.set_attn_processor(attn_processors)
    attention_store.num_att_layers = 32

def offlinePipelineSetupWithSafeTensor(sd_safetensors_path):
    project_dir = os.path.dirname(os.path.dirname(__file__))
    migc_ckpt_path = os.path.join(project_dir, 'pretrained_weights/MIGC_SD14.ckpt')
    clip_model_path = os.path.join(project_dir, 'migc_gui_weights/clip/text_encoder')
    clip_tokenizer_path = os.path.join(project_dir, 'migc_gui_weights/clip/tokenizer')
    original_config_file = os.path.join(project_dir, 'migc_gui_weights/v1-inference.yaml')
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        # text_encoder = CLIPTextModel(config)
        text_encoder = CLIPTextModel.from_pretrained(clip_model_path)
        tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path)
    pipe = StableDiffusionMIGCPipeline.from_single_file(sd_safetensors_path,
                                                        original_config_file=original_config_file,
                                                        text_encoder=text_encoder,
                                                        tokenizer=tokenizer,
                                                        load_safety_checker=False)
    print('Initializing pipeline')
    pipe.attention_store = AttentionStore()
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path, attn_processor=MIGCProcessor)
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe
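

# Usage sketch: a minimal example of wiring the helpers above together, i.e. build
# the pipeline from a local .safetensors checkpoint, seed the RNGs, and run one
# layout-conditioned generation. The checkpoint path is hypothetical, and the
# exact __call__ signature of StableDiffusionMIGCPipeline (nested per-instance
# prompts, normalized xyxy boxes, MIGC-specific arguments such as MIGCsteps) is
# assumed from the MIGC project's examples and may differ in your version.
if __name__ == '__main__':
    pipe = offlinePipelineSetupWithSafeTensor('downloads/sd14_base.safetensors')  # hypothetical path
    pipe = pipe.to('cuda')
    seed_everything(42)

    # One prompt list per image: the first entry describes the whole scene,
    # the remaining entries describe the instances placed by `bboxes`.
    prompt_final = [['masterpiece, best quality, a yellow ball and a red cup',
                     'a yellow ball', 'a red cup']]
    bboxes = [[[0.10, 0.55, 0.45, 0.90],   # yellow ball, normalized [x0, y0, x1, y1]
               [0.55, 0.30, 0.90, 0.80]]]  # red cup

    image = pipe(prompt_final, bboxes,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 MIGCsteps=25).images[0]
    image.save('migc_demo.png')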