# Control_Ability_Arena/model_bbox/MIGC/inference_single_image.py
import os
import sys
import torch
# Make the local MIGC package importable regardless of the current working directory.
migc_path = os.path.dirname(os.path.abspath(__file__))
print(migc_path)
if migc_path not in sys.path:
    sys.path.append(migc_path)
import yaml
from diffusers import EulerDiscreteScheduler
from migc.migc_utils import seed_everything
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
def normalize_bbox(bboxes, img_width, img_height):
    """Scale pixel-space boxes [x_min, y_min, x_max, y_max] to the [0, 1] range
    and wrap them in one extra list, matching the MIGC pipeline's bbox format."""
    normalized_bboxes = []
    for box in bboxes:
        x_min, y_min, x_max, y_max = box
        x_min = x_min / img_width
        y_min = y_min / img_height
        x_max = x_max / img_width
        y_max = y_max / img_height
        normalized_bboxes.append([x_min, y_min, x_max, y_max])
    return [normalized_bboxes]
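
# Example for reference (hypothetical values): a box drawn on a 600x600 canvas is
# mapped to normalized coordinates and wrapped in one extra list:
#   normalize_bbox([[150, 150, 450, 450]], 600, 600)
#   -> [[[0.25, 0.25, 0.75, 0.75]]]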
def create_simple_prompt(input_str):
    # Split the input string on semicolons and drop empty entries.
    objects = [obj for obj in input_str.split(';') if obj.strip()]
    # Build the full caption string.
    prompt_description = "masterpiece, best quality, " + ", ".join(objects)
    # Build the final nested structure: [[caption, phrase_1, phrase_2, ...]].
    prompt_final = [[prompt_description] + objects]
    return prompt_final
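
# Example for reference (hypothetical input): a semicolon-separated grounding
# instruction becomes the nested prompt structure the MIGC pipeline expects:
#   create_simple_prompt("a red apple;a blue cup")
#   -> [['masterpiece, best quality, a red apple, a blue cup', 'a red apple', 'a blue cup']]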
def inference_single_image(prompt, grounding_instruction, state):
    print(prompt)
    print(grounding_instruction)
    bbox = state['boxes']
    print(bbox)
    # Boxes come from a 600x600 drawing canvas; convert them to [0, 1] coordinates.
    bbox = normalize_bbox(bbox, 600, 600)
    print(bbox)
    simple_prompt = create_simple_prompt(grounding_instruction)
    print(simple_prompt)
    migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(migc_ckpt_path_all)
    assert os.path.isfile(migc_ckpt_path_all), \
        "Please download the MIGC checkpoint and put it in the pretrained_weights/ folder!"
    # Prefer a local snapshot of SD 1.4 if it exists; otherwise pull from the Hugging Face Hub.
    sd1x_local_path = ('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/'
                       'snapshots/133a221b8aa7292a167afc5127cb63fb5005638b')
    sd1x_path = sd1x_local_path if os.path.isdir(sd1x_local_path) else "CompVis/stable-diffusion-v1-4"
    # MIGC is a plug-and-play controller.
    # You can go to https://civitai.com/search/models?baseModel=SD%201.4&baseModel=SD%201.5&sortBy=models_v5
    # to find a base model with better generation ability and achieve better creations.
    # Construct the MIGC pipeline.
    pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path)
    pipe.attention_store = AttentionStore()
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store,
              migc_ckpt_path_all, attn_processor=MIGCProcessor)
    pipe = pipe.to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    seed = 7351007268695528845
    seed_everything(seed)
    print("Start inference: ")
    image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
    return image
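
# Note: inference_single_image rebuilds the full Stable Diffusion + MIGC pipeline on
# every call. For repeated generation, construct the pipeline once with MIGC_Pipe()
# below and pass it to inference_image instead.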
def MIGC_Pipe():
    migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt'
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(f"Loading MIGC checkpoint from: {migc_ckpt_path_all}")
    assert os.path.isfile(migc_ckpt_path_all), \
        f"Please download the MIGC checkpoint and put it in the 'pretrained_weights/' folder: {migc_ckpt_path_all}"
    # Prefer a local snapshot of SD 1.4 if it exists; otherwise pull from the Hugging Face Hub.
    sd1x_local_path = ('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/'
                       'snapshots/133a221b8aa7292a167afc5127cb63fb5005638b')
    sd1x_path = sd1x_local_path if os.path.isdir(sd1x_local_path) else "CompVis/stable-diffusion-v1-4"
    print(f"Loading Stable Diffusion base model: {sd1x_path}")
    # Load the StableDiffusionMIGCPipeline.
    print("load sd:")
    pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path)
    pipe.attention_store = AttentionStore()
    # Import and load the MIGC weights.
    print("load migc")
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store, migc_ckpt_path_all, attn_processor=MIGCProcessor)
    # Make sure the UNet and the attention store were loaded correctly.
    assert pipe.unet is not None, "The UNet was not loaded correctly!"
    assert pipe.attention_store is not None, "The attention_store was not loaded correctly!"
    # Move to CUDA if available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA device")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    pipe = pipe.to(device)
    # Set the scheduler.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe
def inference_image(pipe, prompt, grounding_instruction, state):
    print(prompt)
    print(grounding_instruction)
    bbox = state['boxes']
    print(bbox)
    # Boxes come from a 600x600 drawing canvas; convert them to [0, 1] coordinates.
    bbox = normalize_bbox(bbox, 600, 600)
    print(bbox)
    simple_prompt = create_simple_prompt(grounding_instruction)
    print(simple_prompt)
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    seed = 7351007268695528845
    seed_everything(seed)
    print("Start inference: ")
    image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
    return image
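
# Usage sketch (hypothetical box and phrase, shown as a comment only): build the
# pipeline once with MIGC_Pipe() and reuse it across calls to inference_image,
# which avoids reloading the model on every request:
#   pipe = MIGC_Pipe()
#   state = {'boxes': [[150, 150, 450, 450]]}           # pixel coords on a 600x600 canvas
#   img = inference_image(pipe, "a cat", "gray colored cat", state)
#   img.save("cat.png")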
if __name__ == "__main__":
    # Demo layout: inference_single_image expects a semicolon-separated grounding
    # instruction and a state dict with pixel-space boxes on a 600x600 canvas.
    objects = ['black colored ball', 'gray colored cat', 'white colored bed',
               'green colored plant', 'red colored teddy bear', 'blue colored wall',
               'brown colored vase', 'orange colored book', 'yellow colored hat']
    grounding_instruction = ';'.join(objects)
    bboxes = [[0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875],
              [0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125],
              [0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125],
              [0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125],
              [0.609375, 0.09375, 0.90625, 0.28125]]
    # The boxes above are already normalized; scale them back to pixel coordinates so
    # normalize_bbox(..., 600, 600) inside inference_single_image recovers them.
    state = {'boxes': [[coord * 600 for coord in box] for box in bboxes]}
    image = inference_single_image("a cat", grounding_instruction, state)
    image.save("output.png")
    print("done")