Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import torch | |
# Directory that contains this script. It is appended to sys.path (when not
# already present) so that the `migc` package imported below can be resolved.
migc_path = os.path.split(os.path.abspath(__file__))[0]
print(migc_path)
if migc_path not in sys.path:
    sys.path.append(migc_path)
import yaml | |
from diffusers import EulerDiscreteScheduler | |
from migc.migc_utils import seed_everything | |
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore | |
def normalize_bbox(bboxes, img_width, img_height):
    """Scale ``[x_min, y_min, x_max, y_max]`` boxes by the image size.

    x coordinates are divided by ``img_width`` and y coordinates by
    ``img_height``. Each box must have exactly four values; otherwise
    unpacking raises ``ValueError``.

    Args:
        bboxes: iterable of 4-element boxes.
        img_width: divisor applied to ``x_min`` and ``x_max``.
        img_height: divisor applied to ``y_min`` and ``y_max``.

    Returns:
        The scaled boxes wrapped in one extra list (a batch of one), i.e.
        ``[[[x0, y0, x1, y1], ...]]``.
    """
    scaled = [
        [left / img_width, top / img_height, right / img_width, bottom / img_height]
        for left, top, right, bottom in bboxes
    ]
    return [scaled]
def create_simple_prompt(input_str):
    """Build the MIGC prompt structure from ``;``-separated instance phrases.

    Each non-blank segment of ``input_str`` becomes one instance phrase, with
    surrounding whitespace stripped. The global description is the quality
    prefix ``"masterpiece, best quality, "`` followed by all phrases joined
    with ``", "``.

    NOTE(review): presumably phrase i is paired with box i by the pipeline;
    confirm against StableDiffusionMIGCPipeline.

    Args:
        input_str: e.g. ``"red ball; gray cat"``.

    Returns:
        A one-element batch ``[[global_description, phrase_1, phrase_2, ...]]``.

    >>> create_simple_prompt("red ball; gray cat")
    [['masterpiece, best quality, red ball, gray cat', 'red ball', 'gray cat']]
    """
    # Split on ';' and strip each segment, so "a; b" yields "b" rather than " b".
    # Blank segments (e.g. from a trailing ';') are dropped.
    objects = [obj.strip() for obj in input_str.split(';') if obj.strip()]
    # Global description: quality prefix + every instance phrase.
    prompt_description = "masterpiece, best quality, " + ", ".join(objects)
    # Final structure: batch of one, global description first, then the phrases.
    return [[prompt_description] + objects]
def inference_single_image(prompt, grounding_instruction, state):
    """Build a fresh MIGC pipeline and run a single generation with it.

    Thin one-shot wrapper that delegates pipeline construction to
    ``MIGC_Pipe()`` and generation to ``inference_image()``. The wrapper
    therefore picks up the CUDA/CPU fallback and the shared sampling settings,
    instead of keeping a second copy that hard-coded ``"cuda"``.

    The whole pipeline (base SD weights + MIGC checkpoint) is rebuilt on every
    call, which is slow. For repeated generations, build it once with
    ``MIGC_Pipe()`` and call ``inference_image(pipe, ...)`` directly.

    Args:
        prompt: free-text prompt; only logged, since the text actually sent to
            the pipeline is built from ``grounding_instruction``.
        grounding_instruction: ``;``-separated instance phrases,
            e.g. ``"red ball; gray cat"``.
        state: mapping whose ``'boxes'`` entry lists one
            ``[x_min, y_min, x_max, y_max]`` box per phrase, on a 600x600 canvas.

    Returns:
        The first image produced by the pipeline.

    Raises:
        AssertionError: if the MIGC checkpoint file is missing (raised by
            ``MIGC_Pipe``).
    """
    pipe = MIGC_Pipe()
    return inference_image(pipe, prompt, grounding_instruction, state)
# def MIGC_Pipe(): | |
# migc_ckpt_path = 'pretrained_weights/MIGC_SD14.ckpt' | |
# migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path) | |
# print(migc_ckpt_path_all) | |
# assert os.path.isfile(migc_ckpt_path_all), "Please download the ckpt of migc and put it in the pretrained_weighrs/ folder!" | |
# sd1x_path = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b' if os.path.isdir('/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b') else "CompVis/stable-diffusion-v1-4" | |
# pipe = StableDiffusionMIGCPipeline.from_pretrained( | |
# sd1x_path) | |
# pipe.attention_store = AttentionStore() | |
# from migc.migc_utils import load_migc | |
# load_migc(pipe.unet , pipe.attention_store, | |
# migc_ckpt_path_all, attn_processor=MIGCProcessor) | |
# pipe = pipe.to("cuda") | |
# pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) | |
# return pipe | |
def MIGC_Pipe(migc_ckpt_path='pretrained_weights/MIGC_SD14.ckpt', sd1x_path=None, device=None):
    """Build a Stable Diffusion v1.x pipeline with the MIGC controller loaded.

    All parameters are optional. Calling ``MIGC_Pipe()`` with no arguments
    behaves as before.

    Args:
        migc_ckpt_path: MIGC checkpoint, joined onto ``migc_path`` (the script
            directory). An absolute path is used as-is, because
            ``os.path.join`` discards earlier components when a later one is
            absolute.
        sd1x_path: base model directory or Hugging Face Hub id. When ``None``,
            a cached local v1.4 snapshot is used if that directory exists,
            otherwise ``"CompVis/stable-diffusion-v1-4"``.
        device: device passed to ``pipe.to``. When ``None``, CUDA is used if
            ``torch.cuda.is_available()``, otherwise CPU.

    Returns:
        The ``StableDiffusionMIGCPipeline``, moved to ``device``, with an
        ``EulerDiscreteScheduler`` built from its original scheduler config.

    Raises:
        AssertionError: if the MIGC checkpoint file does not exist. This check
            is an ``assert``, so it is skipped under ``python -O``.
    """
    migc_ckpt_path_all = os.path.join(migc_path, migc_ckpt_path)
    print(f"加载 MIGC 权重文件路径: {migc_ckpt_path_all}")
    assert os.path.isfile(migc_ckpt_path_all), f"请下载 MIGC 的 ckpt 文件并将其放在 'pretrained_weights/' 文件夹中: {migc_ckpt_path_all}"

    if sd1x_path is None:
        # Prefer a machine-specific local snapshot when present; otherwise fall
        # back to the Hub id (downloaded/cached by diffusers).
        local_snapshot = '/share/bcy/cache/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/133a221b8aa7292a167afc5127cb63fb5005638b'
        sd1x_path = local_snapshot if os.path.isdir(local_snapshot) else "CompVis/stable-diffusion-v1-4"
    print(f"加载 StableDiffusion 模型: {sd1x_path}")

    # Load the StableDiffusionMIGCPipeline.
    print("load sd:")
    pipe = StableDiffusionMIGCPipeline.from_pretrained(sd1x_path)
    pipe.attention_store = AttentionStore()

    # Import the loader lazily and install the MIGC weights / attention
    # processors into the UNet.
    print("load migc")
    from migc.migc_utils import load_migc
    load_migc(pipe.unet, pipe.attention_store, migc_ckpt_path_all, attn_processor=MIGCProcessor)

    # Sanity checks that the UNet and attention store are still attached.
    assert pipe.unet is not None, "unet 模型未正确加载!"
    assert pipe.attention_store is not None, "attention_store 未正确加载!"

    # Move to the target device, falling back to CPU when CUDA is absent.
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
            print("使用 CUDA 设备")
        else:
            device = torch.device("cpu")
            print("使用 CPU")
    pipe = pipe.to(device)

    # Swap in an Euler scheduler built from the pipeline's scheduler config.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe
def create_simple_prompt(input_str):
    """Build the MIGC prompt structure from ``;``-separated instance phrases.

    NOTE(review): this module defines ``create_simple_prompt`` twice with the
    same contract. This later definition is the one bound at runtime.

    Each non-blank segment of ``input_str`` becomes one instance phrase, with
    surrounding whitespace stripped. The global description is the quality
    prefix ``"masterpiece, best quality, "`` followed by all phrases joined
    with ``", "``.

    Args:
        input_str: e.g. ``"red ball; gray cat"``.

    Returns:
        A one-element batch ``[[global_description, phrase_1, phrase_2, ...]]``.

    >>> create_simple_prompt("red ball; gray cat")
    [['masterpiece, best quality, red ball, gray cat', 'red ball', 'gray cat']]
    """
    # Split on ';' and strip each segment, so "a; b" yields "b" rather than " b".
    # Blank segments (e.g. from a trailing ';') are dropped.
    objects = [obj.strip() for obj in input_str.split(';') if obj.strip()]
    # Global description: quality prefix + every instance phrase.
    prompt_description = "masterpiece, best quality, " + ", ".join(objects)
    # Final structure: batch of one, global description first, then the phrases.
    return [[prompt_description] + objects]
def inference_image(pipe, prompt, grounding_instruction, state, *,
                    canvas_width=600, canvas_height=600,
                    seed=7351007268695528845):
    """Run one MIGC generation with an already-built pipeline.

    Args:
        pipe: pipeline as returned by ``MIGC_Pipe()``.
        prompt: free-text prompt; only printed, since the prompt actually sent
            to the pipeline is built from ``grounding_instruction``.
        grounding_instruction: ``;``-separated instance phrases,
            e.g. ``"red ball; gray cat"``.
        state: mapping whose ``'boxes'`` entry lists one
            ``[x_min, y_min, x_max, y_max]`` box per phrase. Presumably these
            are pixel coordinates on the drawing canvas; confirm with the UI.
        canvas_width: width of the canvas the boxes refer to; x coordinates
            are divided by it. Defaults to 600, the previously hard-coded
            value.
        canvas_height: height of that canvas; y coordinates are divided by it.
            Defaults to 600, the previously hard-coded value.
        seed: seed passed to ``seed_everything`` right before sampling.

    Returns:
        The first entry of the pipeline output's ``.images``.
    """
    print(prompt)
    print(grounding_instruction)
    bbox = state['boxes']
    print(bbox)
    # Divide by the canvas size to get normalized (0-1) box coordinates.
    bbox = normalize_bbox(bbox, canvas_width, canvas_height)
    print(bbox)
    simple_prompt = create_simple_prompt(grounding_instruction)
    print(simple_prompt)
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    # Re-seed before every generation so repeated calls with the same inputs
    # are repeatable, assuming the pipeline draws its noise from the RNGs
    # that seed_everything seeds.
    seed_everything(seed)
    print("Start inference: ")
    image = pipe(simple_prompt, bbox, num_inference_steps=50, guidance_scale=7.5,
                 MIGCsteps=25, aug_phase_with_and=False,
                 negative_prompt=negative_prompt).images[0]
    return image
if __name__ == "__main__":
    # Demo layout: nine coloured objects, given in the pipeline's own format.
    # The prompt is [[global_prompt, phrase_1, ...]] and the boxes are
    # [[box_1, ...]], with coordinates already in [0, 1].
    prompt_final = [[
        'masterpiece, best quality,black colored ball,gray colored cat,white colored bed,'
        'green colored plant,red colored teddy bear,blue colored wall,brown colored vase,orange colored book,'
        'yellow colored hat',
        'black colored ball', 'gray colored cat', 'white colored bed', 'green colored plant',
        'red colored teddy bear', 'blue colored wall', 'brown colored vase', 'orange colored book', 'yellow colored hat',
    ]]
    bboxes = [[[0.3125, 0.609375, 0.625, 0.875], [0.5625, 0.171875, 0.984375, 0.6875],
               [0.0, 0.265625, 0.984375, 0.984375], [0.0, 0.015625, 0.21875, 0.328125],
               [0.171875, 0.109375, 0.546875, 0.515625], [0.234375, 0.0, 1.0, 0.3125],
               [0.71875, 0.625, 0.953125, 0.921875], [0.0625, 0.484375, 0.359375, 0.8125],
               [0.609375, 0.09375, 0.90625, 0.28125]]]
    # inference_single_image() expects a ';'-separated instruction string and a
    # state dict whose 'boxes' are on a 600x600 canvas (it divides them by 600).
    # Passing the lists above directly raised TypeError at state['boxes'], so
    # convert the demo data into that input format here.
    canvas_size = 600
    grounding_instruction = ";".join(prompt_final[0][1:])
    state = {'boxes': [[coord * canvas_size for coord in box] for box in bboxes[0]]}
    image = inference_single_image("a cat", grounding_instruction, state)
    image.save("output.png")
    print("done")