# Spaces:
# Runtime error
# Runtime error
import json | |
import os | |
import cv2 | |
import random | |
import numpy as np | |
import gradio as gr | |
import torch | |
from zhipuai import ZhipuAI | |
from pytorch_lightning import seed_everything | |
from pprint import pprint | |
from PIL import Image, ImageDraw, ImageFont | |
from diffusers import ( | |
ControlNetModel, | |
StableDiffusionControlNetPipeline, | |
) | |
from diffusers import ( | |
DDIMScheduler, | |
PNDMScheduler, | |
EulerAncestralDiscreteScheduler, | |
DPMSolverMultistepScheduler, | |
EulerDiscreteScheduler, | |
LMSDiscreteScheduler, | |
HeunDiscreteScheduler | |
) | |
from controlnet_aux import ( | |
PidiNetDetector, | |
HEDdetector | |
) | |
# Maximum number of text boxes the UI exposes (checkbox/font/text/position slots).
BBOX_MAX_NUM = 8
# NOTE(review): presumably the number of boxes enabled when the UI loads; not referenced in this view.
BBOX_INI_NUM = 0
# Each box's text is cropped to this many characters before rendering.
MAX_LENGTH = 20
# Fall back to CPU when no CUDA device is present instead of failing later inside
# ``pipeline.to(device)``.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Lazily created diffusion pipeline, plus a tag recording what it was loaded for;
# process() compares the tag to decide whether the pipeline must be reloaded.
pipeline = None
pre_pipeline = None
# NOTE(review): REPO_ROOT must be set; if unset, model_root is None and later path
# joins fail.
model_root = os.getenv('REPO_ROOT')
scheduler_root = f'{model_root}/Scheduler'
# UI display name -> checkpoint folder name under ``model_root``.
match_dict = {
    'JoyType.v1.0': 'JoyType-v1-1M',
    'RevAnimated-animation-动漫': 'rev-animated-v1-2-2',
    'GhostMix-animation-动漫': 'GhostMix_V2.0',
    'rpg.v5-fantasy_realism-奇幻写实': 'rpg_v5',
    'midjourneyPapercut-origami-折纸版画': 'midjourneyPapercut_v1',
    'dvarchExterior-architecture-建筑': 'dvarchExterior',
    'awpainting.v13-portrait-人物肖像': 'awpainting_v13',
}
# Display names in UI order; derived from match_dict so the two cannot drift apart.
model_list = list(match_dict)
# Built-in Chinese example prompts with precomputed English translations, so
# process() can use them without calling the online translator.
chn_example_dict = {
    '漂亮的风景照,很多山峰,清澈的湖水': 'beautiful landscape, many peaks, clear lake',
    '画有玫瑰的卡片,明亮的背景': 'a card with roses, bright background',
    '一张关于健康教育的卡片,上面有一些文字,有一些食物图标,背景里有一些水果喝饮料的图标,且背景是模糊的':
        'a card for health education, with some writings on it, '
        'food icons on the card, some fruits and drinking in the background, blur background ',
}
# Fonts offered in the UI as '<LANG>-<name>'; the part after the last '-' is the
# font file stem looked up under ./font/.
font_list = [
    'CHN-华文行楷',
    'CHN-华文新魏',
    'CHN-清松手写体',
    'CHN-巴蜀墨迹',
    'CHN-雷盖体',
    'CHN-演示夏行楷',
    'CHN-鸿雷板书简体',
    'CHN-斑马字类',
    'CHN-青柳隶书',
    'CHN-辰宇落雁体',
    'CHN-宅家麦克笔',
    'ENG-Playwrite',
    'ENG-Okesip',
    'ENG-Shrikhand',
    'ENG-Nextstep',
    'ENG-Filthyrich',
    'ENG-BebasNeue',
    'ENG-Gloock',
    'ENG-Lemon',
    'RUS-Automatons',
    'RUS-MKyrill',
    'RUS-Alice',
    'RUS-Caveat',
    'KOR-ChosunGs',
    'KOR-Dongle',
    'KOR-GodoMaum',
    'KOR-UnDotum',
    'JPN-GlTsukiji',
    'JPN-Aoyagireisyosimo',
    'JPN-KouzanMouhitu',
    'JPN-Otomanopee',
]
# Recommended settings per base model, presumably (inference steps, CFG scale,
# scheduler name); the scheduler names match those accepted by load_pipeline().
_MODEL_DEFAULT_SETTINGS = {
    model_list[0]: (20, 7.5, 'PNDM'),
    model_list[1]: (30, 8.5, 'Euler'),
    model_list[2]: (32, 8.5, 'Euler'),
    model_list[3]: (20, 7.5, 'DPM'),
    model_list[4]: (25, 6.5, 'Euler'),
    model_list[5]: (25, 8.5, 'Euler'),
    model_list[6]: (25, 7, 'DPM'),
}


def change_settings(base_model):
    """Return three Gradio updates applying *base_model*'s recommended settings.

    The updates are returned in the order of ``_MODEL_DEFAULT_SETTINGS`` tuples.
    For a model without an entry, three no-op ``gr.update()`` values are
    returned rather than ``None``, which could not be spread over three outputs.
    """
    settings = _MODEL_DEFAULT_SETTINGS.get(base_model)
    if settings is None:
        return gr.update(), gr.update(), gr.update()
    steps, cfg, scheduler = settings
    return gr.update(value=steps), gr.update(value=cfg), gr.update(value=scheduler)
def update_box_num(choice):
    """Show the first *choice* box rows and hide/reset the remaining ones.

    Returns a flat tuple: BBOX_MAX_NUM checkbox updates, BBOX_MAX_NUM font
    updates, BBOX_MAX_NUM text updates, then 4 * BBOX_MAX_NUM position-slider
    updates (4 per box).
    """
    checks, fonts, texts, boxes = [], [], [], []
    for idx in range(BBOX_MAX_NUM):
        active = idx < choice
        checks.append(gr.update(value=active))
        if active:
            fonts.append(gr.update(visible=True))
            texts.append(gr.update(visible=True))
            # Position sliders stay hidden for active rows too (presumably they are
            # driven from the canvas); their values are left untouched.
            boxes.extend(gr.update(visible=False) for _ in range(4))
        else:
            fonts.append(gr.update(visible=False, value='CHN-华文行楷'))
            texts.append(gr.update(visible=False, value=''))
            # Inactive rows are reset to the default box geometry.
            boxes.extend(gr.update(visible=False, value=v) for v in (0.4, 0.4, 0.2, 0.2))
    return (*checks, *fonts, *texts, *boxes)
def load_box_list(example_id, choice):
    """Restore a saved box layout from ``templates/<example_id>.json``.

    The JSON holds 'visible', 'pos' (4 values per box), 'font' and 'text'
    lists. ``choice`` is accepted for interface compatibility but not used.

    Returns BBOX_MAX_NUM checkbox updates, BBOX_MAX_NUM font updates,
    BBOX_MAX_NUM text updates, 4 * BBOX_MAX_NUM position updates, and a final
    update setting value=-1 (presumably the seed field; -1 means "random" in
    process()).
    """
    with open(f'templates/{example_id}.json', 'r') as fp:
        template = json.load(fp)

    shown = template['visible']
    slots = range(BBOX_MAX_NUM)
    checks = [gr.update(value=shown[k]) for k in slots]
    fonts = [gr.update(value=template['font'][k], visible=shown[k]) for k in slots]
    texts = [gr.update(value=template['text'][k], visible=shown[k]) for k in slots]
    coords = [gr.update(value=template['pos'][k * 4 + j]) for k in slots for j in range(4)]
    return (*checks, *fonts, *texts, *coords, gr.update(value=-1))
def re_edit():
    """Reset all box positions and restore a blank 512x512 canvas.

    Returns 4 * BBOX_MAX_NUM slider updates (0.4, 0.4, 0.2, 0.2 per box), a new
    canvas image component, and two sliders reset to 512 (presumably the
    canvas width and height).
    """
    # BBOX_MAX_NUM is only read here, so no ``global`` declaration is needed.
    defaults = (0.4, 0.4, 0.2, 0.2)
    resets = [gr.update(value=v) for _ in range(BBOX_MAX_NUM) for v in defaults]
    canvas = gr.Image(
        value=create_canvas(),
        label='Rect Position',
        elem_id='MD-bbox-rect-t2i',
        show_label=False,
        visible=True,
    )
    return (*resets, canvas, gr.Slider(value=512), gr.Slider(value=512))
def resize_w(w, img):
    """Return *img* resized to width *w*, keeping its current height."""
    height = img.shape[0]
    return cv2.resize(img, (w, height))
def resize_h(h, img):
    """Return *img* resized to height *h*, keeping its current width."""
    width = img.shape[1]
    return cv2.resize(img, (width, h))
def create_canvas(w=512, h=512, c=3, line=5):
    """Create a light-grey canvas with a darker grid and a red centre marker.

    Args:
        w: canvas width in pixels.
        h: canvas height in pixels.
        c: number of channels.
        line: roughly how many grid cells span the canvas width.

    Returns:
        An ``(h, w, c)`` uint8 numpy array.
    """
    image = np.full((h, w, c), 200, dtype=np.uint8)
    # Spacing is derived from the width for both axes, so grid cells are square.
    # Clamp to 1 so canvases narrower than ``line`` pixels don't divide by zero.
    step = max(w // line, 1)
    image[::step, :, :] = 150
    image[:, ::step, :] = 150
    # 16x16 red marker at the centre; clamp the start so tiny canvases don't
    # produce a negative (wrapping) slice index.
    cy, cx = h // 2, w // 2
    image[max(cy - 8, 0):cy + 8, max(cx - 8, 0):cx + 8, :] = [200, 0, 0]
    return image
def canny(img):
    """Return the Canny edge map of *img* as a 3-channel PIL image.

    Uses fixed hysteresis thresholds (low 64, high 100) and replicates the
    single-channel edge map into three identical channels.
    """
    edges = cv2.Canny(img, 64, 100)
    rgb_edges = np.repeat(edges[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(rgb_edges)
def judge_overlap(coord_list1, coord_list2):
    """Return True when two axis-aligned boxes overlap with positive area.

    Each box is given as ``[x1, y1, x2, y2]`` (left, top, right, bottom).
    Boxes that only touch along an edge are not considered overlapping.
    """
    left1, top1, right1, bottom1 = (coord_list1[k] for k in range(4))
    left2, top2, right2, bottom2 = (coord_list2[k] for k in range(4))
    overlaps_x = left1 < right2 and right1 > left2
    overlaps_y = top1 < bottom2 and bottom1 > top2
    return overlaps_x and overlaps_y
def parse_render_list(box_list, shape, box_num):
    """Validate the UI box inputs and convert them into a render list.

    Args:
        box_list: flat sequence laid out as ``box_num`` enable flags, then
            ``4 * box_num`` relative positions (x, y, width, height as fractions
            of the image size), then ``box_num`` font names, then ``box_num``
            texts.
        shape: ``(width, height)`` of the output image in pixels.
        box_num: number of box slots encoded in ``box_list``.

    Returns:
        ``(render_list, True)`` where each entry is a dict with 'text',
        'polygon' (``[x, y, w, h]`` in pixels) and 'font_name' (font file stem);
        or ``([], False)`` after a Gradio warning when a box has zero area, uses
        a font not in ``font_list``, or overlaps another box.
    """
    width, height = shape[0], shape[1]
    valid_list = box_list[:box_num]
    pos_list = box_list[box_num: 5 * box_num]
    font_name_list = box_list[5 * box_num: 6 * box_num]
    text_list = box_list[6 * box_num: 7 * box_num]
    print(font_name_list, text_list)

    render_list = []
    empty_flag = False
    for i, valid in enumerate(valid_list):
        if not valid:
            continue
        pos = pos_list[i * 4: (i + 1) * 4]
        x = int(pos[0] * width)
        y = int(pos[1] * height)
        w = int(pos[2] * width)
        h = int(pos[3] * height)
        font_name = str(font_name_list[i])
        text = str(text_list[i])
        if text == '':
            # Empty boxes get placeholder text so something is still rendered.
            empty_flag = True
            text = 'JoyType'
        if w <= 0 or h <= 0:
            gr.Warning(f'Area of the box{i + 1} cannot be zero!')
            return [], False
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let unknown fonts through to the renderer.
        if font_name not in font_list:
            gr.Warning('Please choose a correct font!')
            return [], False
        render_list.append({
            'text': text,
            'polygon': [x, y, w, h],
            # Entries look like 'CHN-<name>'; the font file stem is the last part.
            'font_name': font_name.split('-')[-1],
        })

    if empty_flag:
        gr.Warning('Null strings will be filled automatically!')

    # Reject the layout if any two pixel rectangles overlap.
    for i, first in enumerate(render_list):
        x1, y1, w1, h1 = first['polygon']
        for second in render_list[i + 1:]:
            x2, y2, w2, h2 = second['polygon']
            if judge_overlap([x1, y1, x1 + w1, y1 + h1], [x2, y2, x2 + w2, y2 + h2]):
                gr.Warning('Find overlapping boxes!')
                return [], False

    return render_list, True
def _load_font(font_name, size=50):
    """Open ``./font/<font_name>.ttf``, falling back to ``./font/<font_name>.otf``.

    Raises:
        OSError: if neither font file can be opened.
    """
    try:
        return ImageFont.truetype(f'./font/{font_name}.ttf', encoding='utf-8', size=size)
    except OSError:
        # Pillow reports an unopenable font path as a plain OSError ("cannot open
        # resource"), not FileNotFoundError, so the fallback must catch OSError
        # (FileNotFoundError is a subclass and is covered as well).
        return ImageFont.truetype(f'./font/{font_name}.otf', encoding='utf-8', size=size)


def render_all_text(render_list, shape, threshold=512):
    """Render every text box white-on-black onto one image of size ``shape``.

    Each text is drawn at font size 50 on a 1024x1024 scratch image, cropped to
    its measured extent and scaled into its box. When the text's aspect ratio
    is within 0.8-1.2 of the box's, it is stretched to fill the box. Otherwise
    it is scaled to fit the limiting side and centred along the other side.
    Boxes taller than wide are laid out vertically, one character per line.

    Args:
        render_list: dicts with 'text', 'polygon' (``[x, y, w, h]`` in pixels)
            and 'font_name' (file stem under ``./font/``).
        shape: ``(width, height)`` of the output image.
        threshold: unused; kept for interface compatibility.

    Returns:
        A black RGB ``PIL.Image`` with all texts rendered in white.
    """
    width, height = shape[0], shape[1]
    board = Image.new('RGB', (width, height), 'black')
    for text_dict in render_list:
        text = text_dict['text']
        x, y, w, h = text_dict['polygon']
        font_name = text_dict['font_name']
        if len(text) > MAX_LENGTH:
            text = text[:MAX_LENGTH]
            gr.Warning(f'{text}... exceeds the maximum length {MAX_LENGTH} and has been cropped.')

        vertical = w < h
        scratch = Image.new('RGB', (1024, 1024), 'black')
        draw = ImageDraw.Draw(scratch)
        font = _load_font(font_name)
        if not vertical:
            draw.text(xy=(0, 0), text=text, font=font, fill='white')
            _, _, text_w, text_h = draw.textbbox(xy=(0, 0), text=text, font=font)
            text_h += 1
        else:
            text_w, y_c = 0, 0
            for ch in text:
                draw.text(xy=(0, y_c), text=ch, font=font, fill='white')
                left, _, right, bottom = font.getbbox(ch)
                text_w = max(text_w, right - left)
                y_c += bottom
            text_h = y_c + 1
        # Guard against zero-width measurements (e.g. whitespace-only text), which
        # would otherwise divide by zero below.
        text_w = max(text_w, 1)

        # ratio < 1: text is relatively wider than the box; ratio > 1: taller.
        ratio = (text_h * w) / (text_w * h)
        text_img = scratch.crop((0, 0, text_w, text_h))
        x_offset = y_offset = 0
        if 0.8 <= ratio <= 1.2:
            text_img = text_img.resize((w, h))
        elif ratio < 0.8:
            # Fit to the box width and centre vertically. The bound matches the
            # stretch band above; with the old 0.75 bound, ratios in [0.75, 0.8)
            # were fitted to height and overflowed the box width.
            new_h = max(int(text_h * (w / text_w)), 1)
            text_img = text_img.resize((w, new_h))
            y_offset = (h - new_h) // 2
        else:
            # Fit to the box height and centre horizontally.
            new_w = max(int(text_w * (h / text_h)), 1)
            text_img = text_img.resize((new_w, h))
            x_offset = (w - new_w) // 2
        board.paste(text_img, (x + x_offset, y + y_offset))
    return board
# Lower-cased scheduler name -> diffusers scheduler class. Each name is also the
# subfolder under ``scheduler_root`` holding that scheduler's config.
_SCHEDULERS = {
    'pndm': PNDMScheduler,
    'lms': LMSDiscreteScheduler,
    'euler': EulerDiscreteScheduler,
    'dpm': DPMSolverMultistepScheduler,
    'ddim': DDIMScheduler,
    'heun': HeunDiscreteScheduler,
    'euler-ancestral': EulerAncestralDiscreteScheduler,
}


def load_pipeline(model_name, scheduler_name):
    """Build a ControlNet Stable Diffusion pipeline and move it to ``device``.

    Args:
        model_name: base-model folder name under ``model_root``.
        scheduler_name: scheduler name, case-insensitive (e.g. 'PNDM', 'Euler').

    Returns:
        A float32 ``StableDiffusionControlNetPipeline``.

    Raises:
        ValueError: if *scheduler_name* is not a supported scheduler.
    """
    # The ControlNet weights always come from the JoyType checkpoint, whichever
    # base model is selected.
    controlnet_path = os.path.join(model_root, match_dict['JoyType.v1.0'])
    model_path = os.path.join(model_root, model_name)

    key = scheduler_name.lower()
    scheduler_cls = _SCHEDULERS.get(key)
    if scheduler_cls is None:
        raise ValueError(
            f'Unknown scheduler {scheduler_name!r}; expected one of {sorted(_SCHEDULERS)}'
        )
    scheduler = scheduler_cls.from_pretrained(scheduler_root, subfolder=key)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder='controlnet',
        torch_dtype=torch.float32,
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        controlnet=controlnet,
        torch_dtype=torch.float32,
    ).to(device)
    return pipeline
# System prompt (in Chinese) asking GLM to turn a user description into an
# English, comma-separated Stable Diffusion prompt, returning only the prompt.
_PROMPT_SYSTEM_MESSAGE = '''
Stable Diffusion是一款利用深度学习的文生图模型,支持通过使用提示词来产生新的图像,描述要包含或省略的元素。
我在这里引入Stable Diffusion算法中的Prompt概念,又被称为提示符。这里的Prompt通常可以用来描述图像,
他由普通常见的单词构成,最好是可以在数据集来源站点找到的著名标签(比如Ddanbooru)。
下面我将说明Prompt的生出步骤,这里的Prompt主要用于描述人物。在Prompt的生成中,你需要通过提示词来描述 人物属性,主题,外表,情绪,衣服,姿势,视角,动作,背景。
用英语单词或短语甚至自然语言的标签来描述,并不局限于我给你的单词。然后将你想要的相似的提示词组合在一起,请使用英文半角,做分隔符,每个提示词不要带引号,并将这些按从最重要到最不重要的顺序 排列。
另外请您注意,永远在每个 Prompt的前面加上引号里的内容,
“(((best quality))),(((ultra detailed))),(((masterpiece))),illustration,” 这是高质量的标志。
人物属性中,1girl表示你生成了一个女孩,2girls表示生成了两个女孩,一次。另外再注意,Prompt中不能带有-和_。
可以有空格和自然语言,但不要太多,单词不能重复。只返回Prompt。
'''


def preprocess_prompt(prompt):
    """Ask ZhipuAI GLM (model glm-4-0520) to rewrite *prompt* as an English SD prompt.

    The API key is read from the ``ZHIPU_API_KEY`` environment variable.

    Returns:
        ``{'flag': 1, 'data': [reply, ...]}`` with the non-empty reply text of
        each returned choice on success, or ``{'flag': 0, 'data': {}}`` when the
        request raises or yields no text. ``data`` is only meaningful when
        ``flag`` is 1.
    """
    try:
        client = ZhipuAI(api_key=os.getenv('ZHIPU_API_KEY'))
        response = client.chat.completions.create(
            model="glm-4-0520",
            messages=[
                {'role': 'system', 'content': _PROMPT_SYSTEM_MESSAGE},
                {'role': 'user', 'content': prompt},
            ],
            temperature=0.5,
            max_tokens=2048,
            top_p=1,
            stream=False,
        )
    except Exception as err:
        # Remote-service boundary: failures arrive as exceptions (network, auth,
        # API errors) rather than a falsy response, so map them to the
        # ``flag: 0`` result the caller already handles, and log the cause.
        print(f'ZhipuAI request failed: {err!r}')
        return {'flag': 0, 'data': {}}

    replies = [
        choice.message.content
        for choice in (response.choices or [])
        if choice.message.content
    ]
    if not replies:
        return {'flag': 0, 'data': {}}
    return {'flag': 1, 'data': replies}
def _contains_cjk(text):
    """Return True if *text* contains a CJK Unified Ideograph (U+4E00..U+9FFF)."""
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)


def _to_english_prompt(usr_prompt):
    """Return an English prompt for *usr_prompt*, or None if translation failed.

    Built-in Chinese examples use their precomputed translation, prompts
    containing CJK ideographs are translated through preprocess_prompt(), and
    anything else is returned unchanged.
    """
    if usr_prompt in chn_example_dict:
        return chn_example_dict[usr_prompt]
    if not _contains_cjk(usr_prompt):
        return usr_prompt
    data = preprocess_prompt(usr_prompt)
    if data['flag'] != 1:
        return None
    # The reply may be wrapped in quotes; strip them only when actually present
    # instead of blindly dropping the first and last characters.
    return data['data'][0].strip().strip('"\'“”')


def process(
    num_samples,
    a_prompt,
    n_prompt,
    conditioning_scale,
    cfg_scale,
    inference_steps,
    seed,
    usr_prompt,
    rect_img,
    base_model,
    scheduler_name,
    box_num,
    *box_list
):
    """Generate images whose layout follows the rendered text boxes.

    The prompt is translated to English if needed, the boxes are rendered into
    a text image, its Canny edges are used as the ControlNet condition, and the
    (cached) pipeline produces ``num_samples`` images.

    Args:
        num_samples: number of images to generate.
        a_prompt: extra positive prompt appended to the user prompt.
        n_prompt: negative prompt.
        conditioning_scale: ControlNet conditioning scale.
        cfg_scale: classifier-free guidance scale.
        inference_steps: number of denoising steps.
        seed: random seed; -1 picks a random one.
        usr_prompt: user prompt (English or Chinese).
        rect_img: canvas array; its width/height set the output size.
        base_model: UI display name of the base model (key of ``match_dict``).
        scheduler_name: scheduler name passed to load_pipeline().
        box_num: number of box slots encoded in ``box_list``.
        *box_list: flat box inputs, see parse_render_list().

    Returns:
        ``(images, gr.Markdown(info))`` on success, or
        ``(None, gr.Markdown('error'))`` after a Gradio warning on failure.
    """
    global pipeline, pre_pipeline

    if not usr_prompt or not usr_prompt.strip():
        gr.Warning('Must input a prompt!')
        return None, gr.Markdown('error')
    if rect_img is None:
        gr.Warning('Missing canvas image, please reset the canvas!')
        return None, gr.Markdown('error')

    if seed == -1:
        seed = random.randint(0, 2147483647)
    seed_everything(seed)

    # Support Chinese input.
    usr_prompt = _to_english_prompt(usr_prompt)
    if usr_prompt is None:
        gr.Warning('Something went wrong while translating your prompt, please try again.')
        return None, gr.Markdown('error')

    shape = (rect_img.shape[1], rect_img.shape[0])  # (width, height)
    render_list, ok = parse_render_list(box_list, shape, box_num)
    if not ok:
        return None, gr.Markdown('error')
    render_img = render_all_text(render_list, shape)

    model_name = match_dict[base_model]
    control_img = canny(np.array(render_img))
    w, h = control_img.size

    # Reload when either the model or the scheduler changes (keying on the model
    # alone silently ignored scheduler changes). The tag is recorded only after
    # a successful load so a failed load cannot leave a stale pipeline in use.
    cache_key = (model_name, scheduler_name)
    if pipeline is None or pre_pipeline != cache_key:
        pipeline = load_pipeline(model_name, scheduler_name)
        pre_pipeline = cache_key

    # Gradio numeric inputs may arrive as floats; range/list repetition need int.
    batch = int(num_samples)
    images = pipeline(
        [f'{usr_prompt}, {a_prompt}'] * batch,
        negative_prompt=[n_prompt] * batch,
        image=[control_img] * batch,
        controlnet_conditioning_scale=float(conditioning_scale),
        guidance_scale=float(cfg_scale),
        width=w,
        height=h,
        num_inference_steps=int(inference_steps),
    ).images
    return images, gr.Markdown(f'{seed}, {usr_prompt}, {box_list}')
def draw_example(box_list, color, id):
    """Save a box layout as a JSON template plus a PNG preview.

    *box_list* is laid out as BBOX_MAX_NUM visibility flags, 4 * BBOX_MAX_NUM
    relative positions, BBOX_MAX_NUM font names, then the texts. Writes
    ``templates/<id>.json`` and ``./examples/<id>.png``; each box whose flag is
    exactly ``True`` is drawn with outline ``color[i][0]`` and fill
    ``color[i][1]``.
    """
    board = Image.fromarray(create_canvas())
    width, height = board.size
    pen = ImageDraw.Draw(board, mode='RGBA')

    n = BBOX_MAX_NUM
    visible = box_list[:n]
    pos = box_list[n: 5 * n]
    template = {
        'visible': list(visible),
        'pos': list(pos),
        'font': list(box_list[5 * n: 6 * n]),
        'text': list(box_list[6 * n:]),
    }
    with open(f'templates/{id}.json', 'w') as fp:
        json.dump(template, fp)

    for i in range(n):
        if visible[i] is not True:
            continue
        polygon = pos[i * 4: (i + 1) * 4]
        print(polygon)
        left = width * polygon[0]
        top = height * polygon[1]
        right = left + width * polygon[2]
        bottom = top + height * polygon[3]
        pen.rectangle([left, top, right, bottom], outline=color[i][0], fill=color[i][1], width=3)
    board.save(f'./examples/{id}.png')
if __name__ == '__main__':
    # Running this module directly does nothing.
    # NOTE(review): the Gradio UI that wires these callbacks is not in this view;
    # presumably it is built and launched elsewhere — confirm.
    pass