import argparse
import os
import random
from collections import defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
import yaml
from diffusers import EulerDiscreteScheduler
from diffusers.utils.import_utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer

from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore

if is_accelerate_available():
    from accelerate import init_empty_weights
def seed_everything(seed):
    """Seed Python and torch (CPU and all CUDA devices) for reproducible sampling."""
    # np.random.seed(seed)  # NumPy seeding is left disabled in the original code.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
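
# Reproducibility sketch: calling seed_everything with the same seed restores
# the torch RNG state, so repeated sampling yields identical noise, e.g.:
#
#   seed_everything(42); a = torch.randn(4)
#   seed_everything(42); b = torch.randn(4)
#   assert torch.equal(a, b)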
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
# Attention-processor keys of the SD 1.x UNet; every key below must receive a
# processor when unet.set_attn_processor is called.
all_processor_keys = [
    'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor',
    'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor',
    'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'
]
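
# Sketch (an assumption, not used by load_migc): with an instantiated diffusers
# UNet, the same key list can be recovered programmatically instead of being
# hard-coded.
def list_attention_processor_keys(unet):
    """Return the attention-processor keys exposed by a diffusers UNet."""
    return sorted(unet.attn_processors.keys())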
def load_migc(unet, attention_store, pretrained_MIGC_path: Union[str, Dict[str, torch.Tensor]], attn_processor,
              **kwargs):
    state_dict = torch.load(pretrained_MIGC_path, map_location="cpu")
    # Fill attention processors.
    attn_processors = {}
    state_dict = state_dict['state_dict']
    adapter_grouped_dict = defaultdict(dict)
    # Remap the MIGC checkpoint keys from LDM naming (input_blocks / middle_block /
    # output_blocks) onto diffusers UNet naming (down_blocks / mid_block / up_blocks).
    for key, value in state_dict.items():
        key_list = key.split(".")
        assert 'migc' in key_list
        if 'input_blocks' in key_list:
            model_type = 'down_blocks'
        elif 'middle_block' in key_list:
            model_type = 'mid_block'
        else:
            model_type = 'up_blocks'
        index_number = int(key_list[3])
        # LDM input_blocks start with a stem conv at index 0 and insert a
        # downsampler every third block, hence //3 and (%3)-1; output_blocks
        # have no stem, hence //3 and %3.
        if model_type == 'down_blocks':
            input_num1 = str(index_number // 3)
            input_num2 = str((index_number % 3) - 1)
        elif model_type == 'mid_block':
            input_num1 = '0'
            input_num2 = '0'
        else:
            input_num1 = str(index_number // 3)
            input_num2 = str(index_number % 3)
        attn_key_list = [model_type, input_num1, 'attentions', input_num2, 'transformer_blocks', '0']
        if model_type == 'mid_block':
            attn_key_list = [model_type, 'attentions', input_num2, 'transformer_blocks', '0']
        attn_processor_key = '.'.join(attn_key_list)
        # Group every tensor under its target attention block, keyed by the
        # remaining 'migc.*' suffix.
        sub_key = '.'.join(key_list[key_list.index('migc'):])
        adapter_grouped_dict[attn_processor_key][sub_key] = value
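
    # Worked example of the remapping above (illustrative key, assuming the
    # checkpoint follows LDM naming):
    #   'model.diffusion_model.output_blocks.3.1.migc.norm.bias'
    #   -> model_type='up_blocks', index_number=3, input_num1='1', input_num2='0'
    #   -> grouped under 'up_blocks.1.attentions.0.transformer_blocks.0'
    #      with sub_key 'migc.norm.bias'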
    # Create a MIGC processor for every attention block that has MIGC weights.
    config = {'not_use_migc': False}
    for key, value_dict in adapter_grouped_dict.items():
        # Infer the MIGC feature dimension C from the checkpointed norm bias.
        dim = value_dict['migc.norm.bias'].shape[0]
        config['C'] = dim
        key_final = key + '.attn2.processor'
        if key_final.startswith("mid_block"):
            place_in_unet = "mid"
        elif key_final.startswith("up_blocks"):
            place_in_unet = "up"
        elif key_final.startswith("down_blocks"):
            place_in_unet = "down"
        attn_processors[key_final] = attn_processor(config, attention_store, place_in_unet)
        attn_processors[key_final].load_state_dict(value_dict)
        attn_processors[key_final].to(device=unet.device, dtype=unet.dtype)
    # Fall back to plain cross-/self-attention processors for the remaining keys.
    config = {'not_use_migc': True}
    for key in all_processor_keys:
        if key not in attn_processors.keys():
            if key.startswith("mid_block"):
                place_in_unet = "mid"
            elif key.startswith("up_blocks"):
                place_in_unet = "up"
            elif key.startswith("down_blocks"):
                place_in_unet = "down"
            attn_processors[key] = attn_processor(config, attention_store, place_in_unet)
    unet.set_attn_processor(attn_processors)
    # The SD 1.x UNet has 32 attention layers (16 self- + 16 cross-attention).
    attention_store.num_att_layers = 32
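
# Usage sketch for load_migc (assumptions: a standard SD 1.4 pipeline and a
# local MIGC checkpoint path):
#
#   pipe = StableDiffusionMIGCPipeline.from_pretrained('CompVis/stable-diffusion-v1-4')
#   pipe.attention_store = AttentionStore()
#   load_migc(pipe.unet, pipe.attention_store,
#             'pretrained_weights/MIGC_SD14.ckpt', attn_processor=MIGCProcessor)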
def offlinePipelineSetupWithSafeTensor(sd_safetensors_path):
    project_dir = os.path.dirname(os.path.dirname(__file__))
    migc_ckpt_path = os.path.join(project_dir, 'pretrained_weights/MIGC_SD14.ckpt')
    clip_model_path = os.path.join(project_dir, 'migc_gui_weights/clip/text_encoder')
    clip_tokenizer_path = os.path.join(project_dir, 'migc_gui_weights/clip/tokenizer')
    original_config_file = os.path.join(project_dir, 'migc_gui_weights/v1-inference.yaml')
    # Use accelerate's init_empty_weights context when available, as in the
    # upstream setup; otherwise fall back to a no-op context.
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        text_encoder = CLIPTextModel.from_pretrained(clip_model_path)
        tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path)
    pipe = StableDiffusionMIGCPipeline.from_single_file(
        sd_safetensors_path,
        original_config_file=original_config_file,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        load_safety_checker=False)
    print('Initializing pipeline')
    pipe.attention_store = AttentionStore()
    # load_migc is defined in this module; no re-import is needed.
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path, attn_processor=MIGCProcessor)
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe
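
# Usage sketch (the safetensors path is an assumption; adjust to your layout):
#
#   pipe = offlinePipelineSetupWithSafeTensor('models/v1-5-pruned-emaonly.safetensors')
#   pipe = pipe.to('cuda')
#   seed_everything(42)
#   # Generation then follows the StableDiffusionMIGCPipeline API, which consumes
#   # per-instance prompts and bounding boxes; see the MIGC README for the exact
#   # call signature.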