File size: 8,996 Bytes
c92c0ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Script-level helpers for running MIGC (Multi-Instance Generation Controller)
# inference on top of Stable Diffusion 1.4.
import os
import sys
import torch

# Make this file's directory importable so the local `migc` package resolves
# regardless of the caller's working directory.
migc_path = os.path.dirname(os.path.abspath(__file__))
print(migc_path)
if migc_path not in sys.path:
    sys.path.append(migc_path)
import yaml
from diffusers import EulerDiscreteScheduler
from migc.migc_utils import seed_everything
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore

def normalize_bbox(bboxes, img_width, img_height):
    """Scale pixel-space boxes ``[x_min, y_min, x_max, y_max]`` into [0, 1].

    Args:
        bboxes: iterable of 4-element pixel-coordinate boxes.
        img_width: canvas width used to scale the x coordinates.
        img_height: canvas height used to scale the y coordinates.

    Returns:
        The scaled boxes wrapped in one extra list level (a batch of one),
        which is the layout the MIGC pipeline expects.
    """
    scaled = [
        [x0 / img_width, y0 / img_height, x1 / img_width, y1 / img_height]
        for x0, y0, x1, y1 in bboxes
    ]
    return [scaled]

def create_simple_prompt(input_str):
    """Build the MIGC prompt structure from a semicolon-separated object string.

    Splits ``input_str`` on ';', drops blank tokens, and returns
    ``[[full_description, obj1, obj2, ...]]`` where the full description
    prefixes the object list with quality tags.
    """
    # Keep only tokens that contain non-whitespace characters.
    objects = list(filter(str.strip, input_str.split(';')))
    description = "masterpiece, best quality, " + ", ".join(objects)
    return [[description] + objects]


def inference_single_image(prompt, grounding_instruction, state):
    """Build a full MIGC pipeline from scratch and generate one image.

    Args:
        prompt: free-form prompt text (only printed for debugging; the
            effective prompt is derived from ``grounding_instruction``).
        grounding_instruction: semicolon-separated object descriptions,
            one per bounding box.
        state: dict whose 'boxes' entry holds pixel-space
            ``[x_min, y_min, x_max, y_max]`` boxes on a 600x600 canvas.

    Returns:
        The generated PIL image.

    Raises:
        AssertionError: if the MIGC checkpoint file is missing.
    """
    print(prompt)
    print(grounding_instruction)
    bbox = state['boxes']
    print(bbox)
    # Boxes come from a 600x600 drawing canvas; MIGC expects them in [0, 1].
    bbox = normalize_bbox(bbox, 600, 600)
    print(bbox)
    simple_prompt = create_simple_prompt(grounding_instruction)
    print(simple_prompt)
    migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(migc_ckpt_path_all)
    # Fixed typo in the message: "pretrained_weighrs" -> "pretrained_weights".
    assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weights/ folder!"

    # Prefer the locally cached SD 1.4 snapshot; fall back to the hub id.
    # MIGC is a plug-and-play controller, so any SD 1.4/1.5 base model works;
    # see https://civitai.com for alternatives with better generation ability.
    local_sd14 = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b'
    sd1x_path = local_sd14 if os.path.isdir(local_sd14) else "CompVis/stable-diffusion-v1-4"

    # Construct the MIGC pipeline and attach its attention bookkeeping.
    pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path)
    pipe.attention_store = AttentionStore()
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path_all, attn_processor=MIGCProcessor)
    pipe = pipe.to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    # Fixed seed so repeated calls with the same layout reproduce the image.
    seed = 7351007268695528845
    seed_everything(seed)
    print("Start inference: ")
    image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
    return image




# def MIGC_Pipe():
#     migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
#     migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
#     print(migc_ckpt_path_all)
#     assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weighrs/ folder!"
#     sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4"
#     pipe = StableDiffusionMIGCPipeline.from_pretrained(
#         sd1x_path)
#     pipe.attention_store = AttentionStore()
#     from migc.migc_utils import load_migc
#     load_migc(pipe.unet , pipe.attention_store,
#             migc_ckpt_path_all, attn_processor=MIGCProcessor)
#     pipe = pipe.to("cuda")
#     pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
#     return pipe

def MIGC_Pipe():
    """Construct and return a ready-to-use StableDiffusionMIGCPipeline.

    Loads the MIGC checkpoint from ``pretrained_weights/`` (relative to this
    file), loads a Stable Diffusion 1.4 base (local cache if present,
    otherwise the hub id), attaches an AttentionStore, installs the MIGC
    attention processors, moves the pipeline to CUDA when available, and
    swaps in an EulerDiscreteScheduler.

    Raises:
        AssertionError: if the checkpoint is missing or loading fails.
    """
    migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(f"加载 MIGC 权重文件路径: {migc_ckpt_path_all}")
    
    assert os.path.isfile(migc_ckpt_path_all), f"请下载 MIGC 的 ckpt 文件并将其放在 'pretrained_weights/' 文件夹中: {migc_ckpt_path_all}"
    
    # Prefer the locally cached SD 1.4 snapshot; fall back to the hub id.
    sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4"
    print(f"加载 StableDiffusion 模型: {sd1x_path}")
    
    # Load the StableDiffusionMIGCPipeline base model.
    print("load sd:")
    pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path)
    pipe.attention_store = AttentionStore()
    
    # Import and install the MIGC weights / attention processors.
    print("load migc")
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store, migc_ckpt_path_all, attn_processor=MIGCProcessor)
    
    # Sanity-check that the model and attention store were loaded.
    assert pipe.unet is not None, "unet 模型未正确加载!"
    assert pipe.attention_store is not None, "attention_store 未正确加载!"
    
    # Move to CUDA when available, otherwise stay on CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("使用 CUDA 设备")
    else:
        device = torch.device("cpu")
        print("使用 CPU")
    
    pipe = pipe.to(device)
    
    # Replace the default scheduler with an Euler discrete scheduler.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    
    return pipe


def create_simple_prompt(input_str):
    """Build the MIGC prompt structure from a semicolon-separated object string.

    NOTE(review): this redefinition shadows an identical function defined
    earlier in this module — consider deleting one copy.

    Each object token is stripped of surrounding whitespace, so inputs like
    "red cat; blue dog" no longer leak stray spaces into the prompt.

    Args:
        input_str: object descriptions separated by ';'; blank tokens are
            ignored.

    Returns:
        ``[[full_description, obj1, obj2, ...]]`` where the full description
        prefixes the object list with quality tags.
    """
    # Strip each token and drop empties (e.g. from a trailing ';').
    objects = [obj.strip() for obj in input_str.split(';') if obj.strip()]

    prompt_description = "masterpiece, best quality, " + ", ".join(objects)

    return [[prompt_description] + objects]


def inference_image(pipe, prompt, grounding_instruction, state):
    """Run one MIGC generation on an already-constructed pipeline.

    Args:
        pipe: a StableDiffusionMIGCPipeline (e.g. from ``MIGC_Pipe()``).
        prompt: free-form prompt text (only printed for debugging).
        grounding_instruction: semicolon-separated object descriptions.
        state: dict whose 'boxes' entry holds pixel-space boxes on a
            600x600 canvas.

    Returns:
        The generated PIL image.
    """
    print(prompt)
    print(grounding_instruction)
    raw_boxes = state['boxes']
    print(raw_boxes)
    # Scale the 600x600 canvas boxes into the [0, 1] range MIGC expects.
    boxes = normalize_bbox(raw_boxes, 600, 600)
    print(boxes)
    structured_prompt = create_simple_prompt(grounding_instruction)
    print(structured_prompt)
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    # Fixed seed keeps repeated runs reproducible.
    seed_everything(7351007268695528845)
    print("Start inference: ")
    result = pipe(
        structured_prompt,
        boxes,
        num_inference_steps=50,
        guidance_scale=7.5,
        MIGCsteps=25,
        aug_phase_with_and=False,
        negative_prompt=negative_prompt,
    )
    return result.images[0]



if __name__ == "__main__":
    # BUG FIX: the original passed the prompt structure and the normalized
    # bbox list directly to inference_single_image, which instead expects a
    # semicolon-separated grounding string and a state dict with pixel-space
    # 'boxes' (it normalizes by 600x600 and builds the prompt itself).
    objects = [
        'black colored ball', 'gray colored cat', 'white colored  bed',
        'green colored plant', 'red colored teddy bear', 'blue colored wall',
        'brown colored vase', 'orange colored book', 'yellow colored hat',
    ]
    grounding_instruction = ';'.join(objects)

    # Normalized [0, 1] boxes, one per object above.
    norm_boxes = [
        [0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875],
        [0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125],
        [0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125],
        [0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125],
        [0.609375, 0.09375, 0.90625, 0.28125],
    ]
    # inference_single_image normalizes by a 600x600 canvas, so scale back up.
    state = {'boxes': [[coord * 600 for coord in box] for box in norm_boxes]}

    image = inference_single_image("a cat", grounding_instruction, state)
    image.save("output.png")
    print("done")