# JoyType / functions.py
# jiangchen16
# initial commit
# 3c3804b
import json
import os
import cv2
import random
import numpy as np
import gradio as gr
import torch
from zhipuai import ZhipuAI
from pytorch_lightning import seed_everything
from pprint import pprint
from PIL import Image, ImageDraw, ImageFont
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
)
from diffusers import (
DDIMScheduler,
PNDMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
HeunDiscreteScheduler
)
from controlnet_aux import (
PidiNetDetector,
HEDdetector
)
# Number of text-box slots the UI helpers in this module iterate over.
BBOX_MAX_NUM = 8
# Not referenced in this part of the file; presumably the initial number of
# active boxes -- TODO confirm against the UI code.
BBOX_INI_NUM = 0
# Longest text (in characters) render_all_text() draws before cropping it.
MAX_LENGTH = 20
# Torch device the diffusion pipeline is moved to in load_pipeline().
device = 'cuda'
# Cache managed by process(): the loaded pipeline and the key identifying
# what it was loaded for.
pipeline = None
pre_pipeline = None
# Root directory of the model folders; None when REPO_ROOT is not set.
model_root = os.getenv('REPO_ROOT')
# Directory with one sub-folder per scheduler config (see load_pipeline()).
scheduler_root = f'{model_root}/Scheduler'
# Display names of the selectable base models.  The order matters:
# change_settings() matches entries by index, and every name is a key of
# match_dict.
model_list =[
    'JoyType.v1.0', 'RevAnimated-animation-动漫', 'GhostMix-animation-动漫',
    'rpg.v5-fantasy_realism-奇幻写实', 'midjourneyPapercut-origami-折纸版画',
    'dvarchExterior-architecture-建筑', 'awpainting.v13-portrait-人物肖像'
]
# Fixed English prompts for known Chinese example prompts.  process() uses
# these directly instead of calling preprocess_prompt() for them.
chn_example_dict = {
    '漂亮的风景照,很多山峰,清澈的湖水': 'beautiful landscape, many peaks, clear lake',
    '画有玫瑰的卡片,明亮的背景': 'a card with roses, bright background',
    '一张关于健康教育的卡片,上面有一些文字,有一些食物图标,背景里有一些水果喝饮料的图标,且背景是模糊的': \
        'a card for health education, with some writings on it, '
        'food icons on the card, some fruits and drinking in the background, blur background '
}
# Maps each display name from model_list to the folder name that is joined
# onto model_root when the model is loaded (see process()/load_pipeline()).
match_dict = {
    'JoyType.v1.0': 'JoyType-v1-1M',
    'RevAnimated-animation-动漫': 'rev-animated-v1-2-2',
    'GhostMix-animation-动漫': 'GhostMix_V2.0',
    'rpg.v5-fantasy_realism-奇幻写实': 'rpg_v5',
    'midjourneyPapercut-origami-折纸版画': 'midjourneyPapercut_v1',
    'dvarchExterior-architecture-建筑': 'dvarchExterior',
    'awpainting.v13-portrait-人物肖像': 'awpainting_v13'
}
# Selectable fonts, written as '<prefix>-<name>'.  parse_render_list() keeps
# the part after the last '-', and render_all_text() loads
# './font/<name>.ttf' (falling back to '.otf') for it.
font_list = [
    'CHN-华文行楷',
    'CHN-华文新魏',
    'CHN-清松手写体',
    'CHN-巴蜀墨迹',
    'CHN-雷盖体',
    'CHN-演示夏行楷',
    'CHN-鸿雷板书简体',
    'CHN-斑马字类',
    'CHN-青柳隶书',
    'CHN-辰宇落雁体',
    'CHN-宅家麦克笔',
    'ENG-Playwrite',
    'ENG-Okesip',
    'ENG-Shrikhand',
    'ENG-Nextstep',
    'ENG-Filthyrich',
    'ENG-BebasNeue',
    'ENG-Gloock',
    'ENG-Lemon',
    'RUS-Automatons',
    'RUS-MKyrill',
    'RUS-Alice',
    'RUS-Caveat',
    'KOR-ChosunGs',
    'KOR-Dongle',
    'KOR-GodoMaum',
    'KOR-UnDotum',
    'JPN-GlTsukiji',
    'JPN-Aoyagireisyosimo',
    'JPN-KouzanMouhitu',
    'JPN-Otomanopee'
]
def change_settings(base_model):
    """Return the preset widget values for the selected base model.

    The result is a triple of ``gr.update`` objects whose values are, in
    order, two numbers (presumably inference steps and guidance scale) and a
    scheduler name.  ``None`` is returned when ``base_model`` does not match
    any entry of ``model_list``.
    """
    # One preset per entry of model_list, in the same order.
    presets = (
        (20, 7.5, 'PNDM'),
        (30, 8.5, 'Euler'),
        (32, 8.5, 'Euler'),
        (20, 7.5, 'DPM'),
        (25, 6.5, 'Euler'),
        (25, 8.5, 'Euler'),
        (25, 7, 'DPM'),
    )
    for known_model, (first, second, scheduler) in zip(model_list, presets):
        if base_model == known_model:
            return (
                gr.update(value=first),
                gr.update(value=second),
                gr.update(value=scheduler),
            )
    return None
def update_box_num(choice):
    """Enable the first ``choice`` box slots and reset all the others.

    Returns a flat tuple of ``gr.update`` objects: ``BBOX_MAX_NUM`` checkbox
    updates, then ``BBOX_MAX_NUM`` font updates, then ``BBOX_MAX_NUM`` text
    updates, then ``4 * BBOX_MAX_NUM`` bounding-box updates.
    """
    checkbox_updates = []
    font_updates = []
    text_updates = []
    bbox_updates = []
    # Values the four bounding-box widgets of a disabled slot are reset to.
    default_bbox = (0.4, 0.4, 0.2, 0.2)

    for slot in range(BBOX_MAX_NUM):
        if slot < choice:
            # Active slot: tick the checkbox and show font/text; the four
            # bounding-box widgets stay hidden.
            checkbox_updates.append(gr.update(value=True))
            font_updates.append(gr.update(visible=True))
            text_updates.append(gr.update(visible=True))
            bbox_updates += [gr.update(visible=False) for _ in range(4)]
        else:
            # Inactive slot: hide everything and restore the defaults.
            checkbox_updates.append(gr.update(value=False))
            font_updates.append(gr.update(visible=False, value='CHN-华文行楷'))
            text_updates.append(gr.update(visible=False, value=''))
            bbox_updates += [
                gr.update(visible=False, value=default)
                for default in default_bbox
            ]

    return (*checkbox_updates, *font_updates, *text_updates, *bbox_updates)
def load_box_list(example_id, choice):
    """Load a saved box template and turn it into widget updates.

    Args:
        example_id: name (without extension) of the JSON file under
            ``templates/``.
        choice: unused; kept so existing callers keep working.

    Returns:
        A flat tuple of ``gr.update`` objects: ``BBOX_MAX_NUM`` checkbox
        updates, ``BBOX_MAX_NUM`` font updates, ``BBOX_MAX_NUM`` text updates,
        ``4 * BBOX_MAX_NUM`` position updates and finally one update setting
        a further component to ``-1`` (which component that is depends on the
        UI wiring -- confirm there).

    Raises:
        FileNotFoundError: if the template file does not exist.
        KeyError/IndexError: if the template lacks a key or has fewer than
            ``BBOX_MAX_NUM`` entries.
    """
    # Templates hold non-ASCII font names and texts.  JSON is UTF-8 by
    # specification, so do not depend on the platform's default encoding.
    with open(f'templates/{example_id}.json', 'r', encoding='utf-8') as f:
        info = json.load(f)

    checkbox_updates = []
    font_updates = []
    text_updates = []
    pos_updates = []
    for i in range(BBOX_MAX_NUM):
        visible = info['visible'][i]
        pos = info['pos'][i * 4: (i + 1) * 4]
        checkbox_updates.append(gr.update(value=visible))
        font_updates.append(gr.update(value=info['font'][i], visible=visible))
        text_updates.append(gr.update(value=info['text'][i], visible=visible))
        pos_updates.extend(gr.update(value=pos[k]) for k in range(4))

    return (
        *checkbox_updates, *font_updates,
        *text_updates, *pos_updates, gr.update(value=-1)
    )
def re_edit():
    """Reset every bounding box and the drawing canvas to their defaults.

    Returns a flat tuple: ``4 * BBOX_MAX_NUM`` ``gr.update`` objects carrying
    the default box values (0.4, 0.4, 0.2, 0.2 per box), a ``gr.Image``
    showing a fresh canvas, and two ``gr.Slider`` components set to 512.
    """
    default_bbox = (0.4, 0.4, 0.2, 0.2)
    bbox_updates = [
        gr.update(value=default)
        for _ in range(BBOX_MAX_NUM)
        for default in default_bbox
    ]
    fresh_canvas = gr.Image(
        value=create_canvas(),
        label='Rect Position', elem_id='MD-bbox-rect-t2i',
        show_label=False, visible=True
    )
    return (*bbox_updates, fresh_canvas, gr.Slider(value=512), gr.Slider(value=512))
def resize_w(w, img):
    """Resize ``img`` to width ``w`` while keeping its current height."""
    current_height = img.shape[0]
    target_size = (w, current_height)  # cv2 expects (width, height)
    return cv2.resize(img, target_size)
def resize_h(h, img):
    """Resize ``img`` to height ``h`` while keeping its current width."""
    current_width = img.shape[1]
    target_size = (current_width, h)  # cv2 expects (width, height)
    return cv2.resize(img, target_size)
def create_canvas(w=512, h=512, c=3, line=5):
    """Build the light-grey grid image used as the box-drawing canvas.

    Args:
        w: canvas width in pixels.
        h: canvas height in pixels.
        c: number of channels.  The centre marker is written as a 3-value
            colour, so only ``c == 3`` is supported.
        line: divisor for the grid spacing.  Rows and columns both use the
            spacing ``w // line``, so grid cells are square.

    Returns:
        A ``uint8`` array of shape ``(h, w, c)``: background 200, grid lines
        150 and a 16x16 marker of colour ``[200, 0, 0]`` at the centre.

    Raises:
        ZeroDivisionError: if ``line`` is 0.
    """
    image = np.full((h, w, c), 200, dtype=np.uint8)

    # ``w // line`` is 0 when the canvas is narrower than ``line``; the
    # per-pixel modulo test used before raised ZeroDivisionError in that
    # case.  ``abs`` keeps the old result for a negative ``line``.
    step = max(1, abs(w // line))
    image[::step, :, :] = 150
    image[:, ::step, :] = 150

    # Clamp the marker's start to 0: on canvases smaller than 16 pixels a
    # negative start would wrap around and paint the wrong region.
    top = max(0, h // 2 - 8)
    left = max(0, w // 2 - 8)
    image[top:h // 2 + 8, left:w // 2 + 8, :] = [200, 0, 0]
    return image
def canny(img):
    """Run Canny edge detection on ``img`` and return it as a 3-channel PIL image.

    The single-channel edge map (thresholds 64 and 100) is copied into three
    identical channels before conversion.
    """
    edges = cv2.Canny(img, 64, 100)
    channel = edges[:, :, np.newaxis]
    stacked = np.concatenate((channel, channel, channel), axis=2)
    return Image.fromarray(stacked)
def judge_overlap(coord_list1, coord_list2):
    """Tell whether two axis-aligned rectangles overlap.

    Both arguments are ``[left, top, right, bottom]``.  Rectangles that only
    touch along an edge or a corner do not count as overlapping.
    """
    overlaps_on_x = coord_list1[0] < coord_list2[2] and coord_list1[2] > coord_list2[0]
    if not overlaps_on_x:
        return overlaps_on_x
    return coord_list1[1] < coord_list2[3] and coord_list1[3] > coord_list2[1]
def parse_render_list(box_list, shape, box_num):
    """Turn the flat list of box widget values into render instructions.

    ``box_list`` is laid out as ``box_num`` enable flags, ``4 * box_num``
    coordinates (x, y, width, height per box, as fractions of the image
    size), ``box_num`` font names and ``box_num`` texts, in that order.

    Args:
        box_list: flat sequence described above.
        shape: ``(width, height)`` of the target image in pixels.
        box_num: number of box slots encoded in ``box_list``.

    Returns:
        ``(render_list, True)`` on success.  Each entry of ``render_list`` is
        a dict with the keys ``'text'``, ``'polygon'`` (``[left, top, width,
        height]`` in pixels) and ``'font_name'`` (the part of the font choice
        after its last ``'-'``).  ``([], False)`` is returned, after a Gradio
        warning, when an enabled box has zero area, names a font that is not
        in ``font_list``, or overlaps another enabled box.
    """
    width = shape[0]
    height = shape[1]

    valid_list = box_list[:box_num]
    pos_list = box_list[box_num: 5 * box_num]
    font_name_list = box_list[5 * box_num: 6 * box_num]
    text_list = box_list[6 * box_num: 7 * box_num]
    print(font_name_list, text_list)

    render_list = []
    empty_flag = False
    for i, valid in enumerate(valid_list):
        if not valid:
            continue

        pos = pos_list[i * 4: (i + 1) * 4]
        top_left_x = int(pos[0] * width)
        top_left_y = int(pos[1] * height)
        w = int(pos[2] * width)
        h = int(pos[3] * height)
        font_name = str(font_name_list[i])
        text = str(text_list[i])

        if text == '':
            # Empty boxes get a placeholder text; the user is told below.
            empty_flag = True
            text = 'JoyType'

        if w <= 0 or h <= 0:
            gr.Warning(f'Area of the box{i + 1} cannot be zero!')
            return [], False

        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let unknown fonts through.
        if font_name not in font_list:
            gr.Warning('Please choose a correct font!')
            return [], False

        render_list.append({
            'text': text,
            'polygon': [top_left_x, top_left_y, w, h],
            'font_name': font_name.split('-')[-1],
        })

    if empty_flag:
        gr.Warning('Null strings will be filled automatically!')

    # Convert every box to [left, top, right, bottom] once, then compare
    # each pair exactly one time.
    corners = [
        [left, top, left + box_w, top + box_h]
        for left, top, box_w, box_h in (item['polygon'] for item in render_list)
    ]
    for i in range(len(corners)):
        for j in range(i + 1, len(corners)):
            if judge_overlap(corners[i], corners[j]):
                gr.Warning('Find overlapping boxes!')
                return [], False

    return render_list, True
def render_all_text(render_list, shape, threshold=512):
    """Draw every text of ``render_list`` into its box on a black board.

    Each text is first drawn in white at font size 50 on a scratch image,
    measured, cropped, scaled to fit its box and pasted onto the board.
    Text is laid out vertically (one character per row) when the box is
    taller than it is wide.

    Args:
        render_list: dicts with the keys ``'text'``, ``'polygon'``
            (``[left, top, width, height]`` in pixels) and ``'font_name'``,
            as produced by ``parse_render_list``.
        shape: ``(width, height)`` of the board in pixels.
        threshold: unused; kept so existing callers keep working.

    Returns:
        The ``'RGB'`` PIL image holding all rendered texts.

    Raises:
        OSError: if neither ``./font/<font_name>.ttf`` nor
            ``./font/<font_name>.otf`` can be loaded.
    """
    width = shape[0]
    height = shape[1]
    board = Image.new('RGB', (width, height), 'black')

    for text_dict in render_list:
        text = text_dict['text']
        polygon = text_dict['polygon']
        font_name = text_dict['font_name']

        if len(text) > MAX_LENGTH:
            text = text[:MAX_LENGTH]
            gr.Warning(f'{text}... exceeds the maximum length {MAX_LENGTH} and has been cropped.')

        w, h = polygon[2:]
        vert = w < h

        image4ratio = Image.new('RGB', (1024, 1024), 'black')
        draw = ImageDraw.Draw(image4ratio)
        try:
            font = ImageFont.truetype(f'./font/{font_name}.ttf', encoding='utf-8', size=50)
        except OSError:
            # Pillow documents OSError for an unreadable font file.
            # FileNotFoundError is a subclass, so catching only that could
            # skip the .otf fallback.
            font = ImageFont.truetype(f'./font/{font_name}.otf', encoding='utf-8', size=50)

        if not vert:
            draw.text(xy=(0, 0), text=text, font=font, fill='white')
            _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
            _th += 1
        else:
            # Stack the characters top to bottom and track the widest one.
            _tw, y_c = 0, 0
            for c in text:
                draw.text(xy=(0, y_c), text=c, font=font, fill='white')
                _l, _t, _r, _b = font.getbbox(c)
                _tw = max(_tw, _r - _l)
                y_c += _b
            _th = y_c + 1

        if _tw <= 0 or _th <= 0:
            # Nothing measurable was drawn; the ratio below would divide by
            # zero.
            gr.Warning(f'Text "{text}" has no measurable size and was skipped.')
            continue

        # ratio < 1: fitting the text to the box width leaves vertical room.
        # ratio > 1: fitting it to the box height leaves horizontal room.
        ratio = (_th * w) / (_tw * h)
        text_img = image4ratio.crop((0, 0, _tw, _th))
        x_offset, y_offset = 0, 0
        if 0.8 <= ratio <= 1.2:
            # Close enough to the box's aspect ratio: stretch to fill it.
            text_img = text_img.resize((w, h))
        elif ratio < 0.8:
            # The bound was 0.75 before, which sent ratios in [0.75, 0.8) to
            # the fit-to-height branch and made the text wider than its box.
            fitted_h = max(1, int(_th * (w / _tw)))
            text_img = text_img.resize((w, fitted_h))
            y_offset = (h - fitted_h) // 2
        else:
            fitted_w = max(1, int(_tw * (h / _th)))
            text_img = text_img.resize((fitted_w, h))
            x_offset = (w - fitted_w) // 2

        board.paste(text_img, (polygon[0] + x_offset, polygon[1] + y_offset))

    return board
def load_pipeline(model_name, scheduler_name):
    """Load the ControlNet diffusion pipeline for a base model and scheduler.

    Args:
        model_name: folder name of the base model under ``model_root``.
        scheduler_name: one of ``pndm``, ``lms``, ``euler``, ``dpm``,
            ``ddim``, ``heun`` or ``euler-ancestral`` (case-insensitive).

    Returns:
        A ``StableDiffusionControlNetPipeline`` moved to ``device``.  It uses
        the JoyType ControlNet weights whatever base model is chosen.

    Raises:
        ValueError: if ``scheduler_name`` is not a supported scheduler.
            Previously this surfaced as an UnboundLocalError.
    """
    # Scheduler key -> class.  The key is also the name of the config
    # sub-folder under ``scheduler_root``.
    scheduler_classes = {
        'pndm': PNDMScheduler,
        'lms': LMSDiscreteScheduler,
        'euler': EulerDiscreteScheduler,
        'dpm': DPMSolverMultistepScheduler,
        'ddim': DDIMScheduler,
        'heun': HeunDiscreteScheduler,
        'euler-ancestral': EulerAncestralDiscreteScheduler,
    }
    scheduler_key = scheduler_name.lower()
    if scheduler_key not in scheduler_classes:
        supported = ', '.join(scheduler_classes)
        raise ValueError(
            f'Unsupported scheduler {scheduler_name!r}; expected one of: {supported}'
        )

    controlnet_path = os.path.join(model_root, match_dict['JoyType.v1.0'])
    model_path = os.path.join(model_root, model_name)

    scheduler = scheduler_classes[scheduler_key].from_pretrained(
        scheduler_root, subfolder=scheduler_key
    )
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder='controlnet',
        torch_dtype=torch.float32
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        controlnet=controlnet,
        torch_dtype=torch.float32,
    ).to(device)
    return pipeline
def preprocess_prompt(prompt):
    """Ask the ZhipuAI chat model to turn ``prompt`` into a Stable Diffusion prompt.

    The API key is read from the ``ZHIPU_API_KEY`` environment variable.

    Returns:
        ``{'flag': 1, 'data': [...]}`` holding the message content of every
        returned choice, or ``{'flag': 0, 'data': {}}`` when the response is
        falsy.
    """
    # Instructions sent as the system message.  This is runtime text for the
    # model, so it is kept exactly as is.
    system_message = {
        'role': 'system',
        'content': '''
Stable Diffusion是一款利用深度学习的文生图模型,支持通过使用提示词来产生新的图像,描述要包含或省略的元素。
我在这里引入Stable Diffusion算法中的Prompt概念,又被称为提示符。这里的Prompt通常可以用来描述图像,
他由普通常见的单词构成,最好是可以在数据集来源站点找到的著名标签(比如Ddanbooru)。
下面我将说明Prompt的生出步骤,这里的Prompt主要用于描述人物。在Prompt的生成中,你需要通过提示词来描述 人物属性,主题,外表,情绪,衣服,姿势,视角,动作,背景。
用英语单词或短语甚至自然语言的标签来描述,并不局限于我给你的单词。然后将你想要的相似的提示词组合在一起,请使用英文半角,做分隔符,每个提示词不要带引号,并将这些按从最重要到最不重要的顺序 排列。
另外请您注意,永远在每个 Prompt的前面加上引号里的内容,
“(((best quality))),(((ultra detailed))),(((masterpiece))),illustration,” 这是高质量的标志。
人物属性中,1girl表示你生成了一个女孩,2girls表示生成了两个女孩,一次。另外再注意,Prompt中不能带有-和_。
可以有空格和自然语言,但不要太多,单词不能重复。只返回Prompt。
'''
    }
    user_message = {
        'role': 'user',
        'content': prompt
    }

    client = ZhipuAI(api_key=os.getenv('ZHIPU_API_KEY'))
    response = client.chat.completions.create(
        model="glm-4-0520",
        messages=[system_message, user_message],
        temperature=0.5,
        max_tokens=2048,
        top_p=1,
        stream=False,
    )

    if not response:
        return {'flag': 0, 'data': {}}
    return {
        'flag': 1,
        'data': [choice.message.content for choice in response.choices]
    }
def process(
    num_samples,
    a_prompt,
    n_prompt,
    conditioning_scale,
    cfg_scale,
    inference_steps,
    seed,
    usr_prompt,
    rect_img,
    base_model,
    scheduler_name,
    box_num,
    *box_list
):
    """Render the requested texts and run the ControlNet pipeline on them.

    Args:
        num_samples: number of images to generate in one batch.
        a_prompt: text appended to the user prompt.
        n_prompt: negative prompt.
        conditioning_scale: passed as ``controlnet_conditioning_scale``.
        cfg_scale: passed as ``guidance_scale``.
        inference_steps: passed as ``num_inference_steps``.
        seed: random seed; ``-1`` picks a random one.
        usr_prompt: user prompt.  Known Chinese examples are replaced by
            their fixed translation; other prompts containing CJK characters
            are sent through ``preprocess_prompt``.
        rect_img: array whose shape gives the output height and width.
        base_model: display name of the base model (a key of ``match_dict``).
        scheduler_name: scheduler to load the pipeline with.
        box_num: number of box slots encoded in ``box_list``.
        *box_list: flat box values, see ``parse_render_list``.

    Returns:
        ``(images, gr.Markdown)`` on success, where the markdown holds the
        seed, the final prompt and the box values; ``(None,
        gr.Markdown('error'))`` when the input is rejected.
    """
    global pipeline, pre_pipeline

    if usr_prompt == '':
        gr.Warning('Must input a prompt!')
        return None, gr.Markdown('error')

    if seed == -1:
        seed = random.randint(0, 2147483647)
    seed_everything(seed)

    # Support Chinese Input
    if usr_prompt in chn_example_dict:
        usr_prompt = chn_example_dict[usr_prompt]
    elif any('\u4e00' <= ch <= '\u9fff' for ch in usr_prompt):
        data = preprocess_prompt(usr_prompt)
        # Also require a non-empty first answer; indexing an empty result
        # raised IndexError before.
        if data['flag'] == 1 and data['data'] and data['data'][0]:
            # NOTE(review): drops the first and last character, presumably
            # the quotes around the model's answer -- confirm.
            usr_prompt = data['data'][0][1: -1]
        else:
            gr.Warning('Something went wrong while translating your prompt, please try again.')
            return None, gr.Markdown('error')

    shape = (rect_img.shape[1], rect_img.shape[0])
    render_list, flag = parse_render_list(box_list, shape, box_num)
    if not flag:
        return None, gr.Markdown('error')
    render_img = render_all_text(render_list, shape)

    model_name = match_dict[base_model]
    render_img = canny(np.array(render_img))
    w, h = render_img.size

    # The scheduler is fixed when the pipeline is built, so it is part of the
    # cache key; otherwise changing only the scheduler in the UI would
    # silently keep the old one.  The key is stored only after a successful
    # load, so a failed load cannot leave a stale pipeline marked as current.
    pipeline_key = (model_name, scheduler_name)
    if pre_pipeline != pipeline_key or pipeline is None:
        pipeline = load_pipeline(model_name, scheduler_name)
        pre_pipeline = pipeline_key

    batch_render_img = [render_img for _ in range(num_samples)]
    batch_prompt = [f'{usr_prompt}, {a_prompt}' for _ in range(num_samples)]
    batch_n_prompt = [n_prompt for _ in range(num_samples)]

    images = pipeline(
        batch_prompt,
        negative_prompt=batch_n_prompt,
        image=batch_render_img,
        controlnet_conditioning_scale=float(conditioning_scale),
        guidance_scale=float(cfg_scale),
        width=w,
        height=h,
        num_inference_steps=int(inference_steps),
    ).images

    return images, gr.Markdown(f'{seed}, {usr_prompt}, {box_list}')
def draw_example(box_list, color, id):
    """Save a box layout as a JSON template and as a preview image.

    ``box_list`` is laid out as ``BBOX_MAX_NUM`` visibility flags,
    ``4 * BBOX_MAX_NUM`` coordinates (x, y, width, height per box, as
    fractions of the canvas size), ``BBOX_MAX_NUM`` fonts and then the texts.
    The layout is written to ``templates/<id>.json``.  Every box whose flag
    is exactly ``True`` is drawn on a fresh canvas with ``color[i][0]`` as
    outline and ``color[i][1]`` as fill, and the canvas is saved to
    ``./examples/<id>.png``.
    """
    board = Image.fromarray(create_canvas())
    board_w, board_h = board.size
    painter = ImageDraw.Draw(board, mode='RGBA')

    slots = BBOX_MAX_NUM
    visible = box_list[:slots]
    pos = box_list[slots: 5 * slots]
    info = {
        'visible': list(visible),
        'pos': list(pos),
        'font': list(box_list[5 * slots: 6 * slots]),
        'text': list(box_list[6 * slots:])
    }
    with open(f'templates/{id}.json', 'w') as template_file:
        json.dump(info, template_file)

    for slot in range(BBOX_MAX_NUM):
        if visible[slot] is not True:
            continue
        polygon = pos[slot * 4: (slot + 1) * 4]
        print(polygon)
        left = board_w * polygon[0]
        top = board_h * polygon[1]
        right = left + board_w * polygon[2]
        bottom = top + board_h * polygon[3]
        outline_color = color[slot][0]
        fill_color = color[slot][1]
        painter.rectangle(
            [left, top, right, bottom],
            outline=outline_color, fill=fill_color, width=3
        )

    board.save(f'./examples/{id}.png')
# Running this file directly is intentionally a no-op.
if __name__ == '__main__':
    pass