import json
import os
import random

import cv2
import numpy as np
import gradio as gr
import torch
from zhipuai import ZhipuAI
from pytorch_lightning import seed_everything
from pprint import pprint
from PIL import Image, ImageDraw, ImageFont
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
)
from diffusers import (
    DDIMScheduler,
    PNDMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    HeunDiscreteScheduler
)
from controlnet_aux import (
    PidiNetDetector,
    HEDdetector
)

# Maximum number of text bounding boxes the UI exposes.
BBOX_MAX_NUM = 8
# Number of boxes shown initially.
BBOX_INI_NUM = 0
# Longest text allowed per box; longer strings are truncated at render time.
MAX_LENGTH = 20

device = 'cuda'

# Lazily-created generation pipeline and the model name it was built for,
# so we only reload weights when the user switches base models.
pipeline = None
pre_pipeline = None

model_root = os.getenv('REPO_ROOT')
scheduler_root = f'{model_root}/Scheduler'

# Display names offered in the model dropdown.
model_list = [
    'JoyType.v1.0',
    'RevAnimated-animation-动漫',
    'GhostMix-animation-动漫',
    'rpg.v5-fantasy_realism-奇幻写实',
    'midjourneyPapercut-origami-折纸版画',
    'dvarchExterior-architecture-建筑',
    'awpainting.v13-portrait-人物肖像'
]

# Canned Chinese prompts mapped to pre-translated English equivalents,
# so these examples skip the online translation step.
chn_example_dict = {
    '漂亮的风景照,很多山峰,清澈的湖水': 'beautiful landscape, many peaks, clear lake',
    '画有玫瑰的卡片,明亮的背景': 'a card with roses, bright background',
    '一张关于健康教育的卡片,上面有一些文字,有一些食物图标,背景里有一些水果喝饮料的图标,且背景是模糊的': \
        'a card for health education, with some writings on it, '
        'food icons on the card, some fruits and drinking in the background, blur background '
}

# Dropdown display name -> on-disk checkpoint directory name.
match_dict = {
    'JoyType.v1.0': 'JoyType-v1-1M',
    'RevAnimated-animation-动漫': 'rev-animated-v1-2-2',
    'GhostMix-animation-动漫': 'GhostMix_V2.0',
    'rpg.v5-fantasy_realism-奇幻写实': 'rpg_v5',
    'midjourneyPapercut-origami-折纸版画': 'midjourneyPapercut_v1',
    'dvarchExterior-architecture-建筑': 'dvarchExterior',
    'awpainting.v13-portrait-人物肖像': 'awpainting_v13'
}

# Selectable fonts; the part after the first '-' is the font file stem
# under ./font/ (see render_all_text / parse_render_list).
font_list = [
    'CHN-华文行楷', 'CHN-华文新魏', 'CHN-清松手写体', 'CHN-巴蜀墨迹',
    'CHN-雷盖体', 'CHN-演示夏行楷', 'CHN-鸿雷板书简体', 'CHN-斑马字类',
    'CHN-青柳隶书', 'CHN-辰宇落雁体', 'CHN-宅家麦克笔',
    'ENG-Playwrite', 'ENG-Okesip', 'ENG-Shrikhand', 'ENG-Nextstep',
    'ENG-Filthyrich', 'ENG-BebasNeue', 'ENG-Gloock', 'ENG-Lemon',
    'RUS-Automatons', 'RUS-MKyrill', 'RUS-Alice', 'RUS-Caveat',
    'KOR-ChosunGs', 'KOR-Dongle', 'KOR-GodoMaum', 'KOR-UnDotum',
    'JPN-GlTsukiji', 'JPN-Aoyagireisyosimo', 'JPN-KouzanMouhitu', 'JPN-Otomanopee'
]


def change_settings(base_model):
    """Return Gradio updates (steps, cfg scale, scheduler) matching the chosen base model.

    Fix: the original ended with `else: pass`, returning None for an unknown
    model and breaking the 3-output binding; we now return no-op updates.
    """
    # (inference steps, guidance scale, scheduler) recommended per model.
    presets = {
        model_list[0]: (20, 7.5, 'PNDM'),
        model_list[1]: (30, 8.5, 'Euler'),
        model_list[2]: (32, 8.5, 'Euler'),
        model_list[3]: (20, 7.5, 'DPM'),
        model_list[4]: (25, 6.5, 'Euler'),
        model_list[5]: (25, 8.5, 'Euler'),
        model_list[6]: (25, 7, 'DPM'),
    }
    if base_model not in presets:
        return gr.update(), gr.update(), gr.update()
    steps, cfg, scheduler = presets[base_model]
    return gr.update(value=steps), gr.update(value=cfg), gr.update(value=scheduler)


def update_box_num(choice):
    """Show the first `choice` box widgets and hide/reset the rest.

    Returns, flattened: BBOX_MAX_NUM checkbox updates, font-dropdown updates,
    text updates, then 4 position-slider updates per box.
    """
    update_list_1 = []  # checkbox
    update_list_2 = []  # font
    update_list_3 = []  # text
    update_list_4 = []  # bounding box
    for i in range(BBOX_MAX_NUM):
        if i < choice:
            update_list_1.append(gr.update(value=True))
            update_list_2.append(gr.update(visible=True))
            update_list_3.append(gr.update(visible=True))
            # NOTE(review): sliders stay hidden even for active boxes —
            # presumably positions are edited on the canvas instead; confirm.
            update_list_4.extend([gr.update(visible=False) for _ in range(4)])
        else:
            update_list_1.append(gr.update(value=False))
            update_list_2.append(gr.update(visible=False, value='CHN-华文行楷'))
            update_list_3.append(gr.update(visible=False, value=''))
            update_list_4.extend([
                gr.update(visible=False, value=0.4),
                gr.update(visible=False, value=0.4),
                gr.update(visible=False, value=0.2),
                gr.update(visible=False, value=0.2)
            ])
    return *update_list_1, *update_list_2, *update_list_3, *update_list_4


def load_box_list(example_id, choice):
    """Load a saved template (templates/<example_id>.json) into the box widgets.

    The trailing `gr.update(value=-1)` resets the seed control.
    """
    with open(f'templates/{example_id}.json', 'r', encoding='utf-8') as f:
        info = json.load(f)
    update_list1 = []
    update_list2 = []
    update_list3 = []
    update_list4 = []
    for i in range(BBOX_MAX_NUM):
        visible = info['visible'][i]
        pos = info['pos'][i * 4: (i + 1) * 4]
        update_list1.append(gr.update(value=visible))
        update_list2.append(gr.update(value=info['font'][i], visible=visible))
        update_list3.append(gr.update(value=info['text'][i], visible=visible))
        update_list4.extend([
            gr.update(value=pos[0]),
            gr.update(value=pos[1]),
            gr.update(value=pos[2]),
            gr.update(value=pos[3])
        ])
    return *update_list1, *update_list2, \
        *update_list3, *update_list4, gr.update(value=-1)


def re_edit():
    """Reset all box sliders to defaults and restore a fresh 512x512 canvas."""
    update_list = []
    for _ in range(BBOX_MAX_NUM):
        update_list.extend([
            gr.update(value=0.4), gr.update(value=0.4),
            gr.update(value=0.2), gr.update(value=0.2)
        ])
    return *update_list, \
        gr.Image(
            value=create_canvas(),
            label='Rect Position',
            elem_id='MD-bbox-rect-t2i',
            show_label=False,
            visible=True
        ), \
        gr.Slider(value=512), gr.Slider(value=512)


def resize_w(w, img):
    """Resize `img` to width `w`, keeping its current height."""
    return cv2.resize(img, (w, img.shape[0]))


def resize_h(h, img):
    """Resize `img` to height `h`, keeping its current width."""
    return cv2.resize(img, (img.shape[1], h))


def create_canvas(w=512, h=512, c=3, line=5):
    """Build a light-gray canvas with a darker grid and a red center marker.

    Fix: row spacing previously used `w // line`, drawing a wrong grid
    whenever the canvas is not square; rows now use `h // line`.
    """
    image = np.full((h, w, c), 200, dtype=np.uint8)
    for i in range(h):
        if i % (h // line) == 0:
            image[i, :, :] = 150
    for j in range(w):
        if j % (w // line) == 0:
            image[:, j, :] = 150
    # 16x16 red square marks the canvas center.
    image[h // 2 - 8:h // 2 + 8, w // 2 - 8:w // 2 + 8, :] = [200, 0, 0]
    return image


def canny(img):
    """Return a 3-channel PIL image of Canny edges for the given array."""
    low_threshold = 64
    high_threshold = 100
    img = cv2.Canny(img, low_threshold, high_threshold)
    img = img[:, :, None]
    img = np.concatenate([img, img, img], axis=2)
    return Image.fromarray(img)


def judge_overlap(coord_list1, coord_list2):
    """True if two (x1, y1, x2, y2) rectangles overlap (strictly)."""
    judge = coord_list1[0] < coord_list2[2] and coord_list1[2] > coord_list2[0] \
        and coord_list1[1] < coord_list2[3] and coord_list1[3] > coord_list2[1]
    return judge


def parse_render_list(box_list, shape, box_num):
    """Validate the flat widget values and build a render list.

    `box_list` is the flattened widget state: box_num valid flags, then
    4*box_num positions (fractions of width/height), box_num font names,
    box_num texts. Returns (render_list, ok) where each entry holds
    'text', 'polygon' [x, y, w, h] in pixels, and 'font_name'.
    Emits gr.Warning and returns ([], False) on any invalid input.
    """
    width = shape[0]
    height = shape[1]
    polygons = []
    font_names = []
    texts = []
    valid_list = box_list[:box_num]
    pos_list = box_list[box_num: 5 * box_num]
    font_name_list = box_list[5 * box_num: 6 * box_num]
    text_list = box_list[6 * box_num: 7 * box_num]
    empty_flag = False
    print(font_name_list, text_list)
    for i, valid in enumerate(valid_list):
        if not valid:
            continue
        pos = pos_list[i * 4: (i + 1) * 4]
        top_left_x = int(pos[0] * width)
        top_left_y = int(pos[1] * height)
        w = int(pos[2] * width)
        h = int(pos[3] * height)
        font_name = str(font_name_list[i])
        text = str(text_list[i])
        if text == '':
            # Placeholder text for empty boxes; user is warned below.
            empty_flag = True
            text = 'JoyType'
        if w <= 0 or h <= 0:
            gr.Warning(f'Area of the box{i + 1} cannot be zero!')
            return [], False
        # Fix: was an `assert` inside `except Exception` — asserts are
        # stripped under -O; use an explicit membership check instead.
        if font_name not in font_list:
            gr.Warning('Please choose a correct font!')
            return [], False
        polygons.append([top_left_x, top_left_y, w, h])
        # Strip the language prefix ('CHN-...') to get the font file stem.
        font_names.append(font_name.split('-')[-1])
        texts.append(text)
    if empty_flag:
        gr.Warning('Null strings will be filled automatically!')
    # Reject any pair of overlapping boxes.
    for i in range(len(polygons)):
        for j in range(i + 1, len(polygons)):
            if judge_overlap(
                [polygons[i][0], polygons[i][1],
                 polygons[i][0] + polygons[i][2], polygons[i][1] + polygons[i][3]],
                [polygons[j][0], polygons[j][1],
                 polygons[j][0] + polygons[j][2], polygons[j][1] + polygons[j][3]]
            ):
                gr.Warning('Find overlapping boxes!')
                return [], False
    render_list = [
        {'text': texts[i], 'polygon': polygons[i], 'font_name': font_names[i]}
        for i in range(len(polygons))
    ]
    return render_list, True


def render_all_text(render_list, shape, threshold=512):
    """Draw every text entry in white on a black board of the given shape.

    Text is rendered at a fixed size on a scratch image, measured, then
    scaled into its box (vertically, one glyph per line, when the box is
    taller than wide). Returns the PIL board used as the ControlNet hint.
    """
    width = shape[0]
    height = shape[1]
    board = Image.new('RGB', (width, height), 'black')
    for text_dict in render_list:
        text = text_dict['text']
        polygon = text_dict['polygon']
        font_name = text_dict['font_name']
        if len(text) > MAX_LENGTH:
            text = text[:MAX_LENGTH]
            gr.Warning(f'{text}... exceeds the maximum length {MAX_LENGTH} and has been cropped.')
        w, h = polygon[2:]
        vert = True if w < h else False
        image4ratio = Image.new('RGB', (1024, 1024), 'black')
        draw = ImageDraw.Draw(image4ratio)
        try:
            font = ImageFont.truetype(f'./font/{font_name}.ttf', encoding='utf-8', size=50)
        except FileNotFoundError:
            font = ImageFont.truetype(f'./font/{font_name}.otf', encoding='utf-8', size=50)
        if not vert:
            draw.text(xy=(0, 0), text=text, font=font, fill='white')
            _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
            _th += 1
        else:
            # Vertical layout: stack glyphs top-to-bottom, track max width.
            _tw, y_c = 0, 0
            for c in text:
                draw.text(xy=(0, y_c), text=c, font=font, fill='white')
                _l, _t, _r, _b = font.getbbox(c)
                _tw = max(_tw, _r - _l)
                y_c += _b
            _th = y_c + 1
        # Aspect-ratio mismatch between rendered text and target box.
        ratio = (_th * w) / (_tw * h)
        text_img = image4ratio.crop((0, 0, _tw, _th))
        x_offset, y_offset = 0, 0
        if 0.8 <= ratio <= 1.2:
            # Close enough: stretch directly into the box.
            text_img = text_img.resize((w, h))
        elif ratio < 0.8:
            # Fix: the original tested `ratio < 0.75`, leaving the
            # (0.75, 0.8) gap to fall into the branch below and produce a
            # negative x_offset. Fit to width, center vertically.
            fit_h = int(_th * (w / _tw))
            text_img = text_img.resize((w, fit_h))
            y_offset = (h - fit_h) // 2
        else:
            # Fit to height, center horizontally. (Locals renamed from
            # resize_w/resize_h, which shadowed the module functions.)
            fit_w = int(_tw * (h / _th))
            text_img = text_img.resize((fit_w, h))
            x_offset = (w - fit_w) // 2
        board.paste(text_img, (polygon[0] + x_offset, polygon[1] + y_offset))
    return board


def load_pipeline(model_name, scheduler_name):
    """Build a StableDiffusionControlNetPipeline for `model_name`.

    The ControlNet weights always come from the JoyType checkpoint; the
    scheduler is loaded from `scheduler_root` by (case-insensitive) name.
    Raises ValueError for an unknown scheduler (the original silently hit
    UnboundLocalError instead).
    """
    controlnet_path = os.path.join(model_root, f'{match_dict["JoyType.v1.0"]}')
    model_path = os.path.join(model_root, model_name)
    scheduler_name = scheduler_name.lower()
    # Scheduler name -> (class, subfolder) dispatch table.
    scheduler_map = {
        'pndm': PNDMScheduler,
        'lms': LMSDiscreteScheduler,
        'euler': EulerDiscreteScheduler,
        'dpm': DPMSolverMultistepScheduler,
        'ddim': DDIMScheduler,
        'heun': HeunDiscreteScheduler,
        'euler-ancestral': EulerAncestralDiscreteScheduler,
    }
    if scheduler_name not in scheduler_map:
        raise ValueError(f'Unknown scheduler: {scheduler_name}')
    scheduler = scheduler_map[scheduler_name].from_pretrained(
        scheduler_root, subfolder=scheduler_name
    )
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder='controlnet',
        torch_dtype=torch.float32
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        controlnet=controlnet,
        torch_dtype=torch.float32,
    ).to(device)
    return pipeline


def preprocess_prompt(prompt):
    """Translate/expand a (Chinese) user prompt into an SD prompt via ZhipuAI.

    Returns {'flag': 1, 'data': [completions...]} on success, otherwise
    {'flag': 0, 'data': {}}. Requires ZHIPU_API_KEY in the environment.
    """
    client = ZhipuAI(api_key=os.getenv('ZHIPU_API_KEY'))
    response = client.chat.completions.create(
        model="glm-4-0520",
        messages=[
            {
                'role': 'system',
                'content': '''
                Stable Diffusion是一款利用深度学习的文生图模型,支持通过使用提示词来产生新的图像,描述要包含或省略的元素。
                我在这里引入Stable Diffusion算法中的Prompt概念,又被称为提示符。这里的Prompt通常可以用来描述图像,
                他由普通常见的单词构成,最好是可以在数据集来源站点找到的著名标签(比如Ddanbooru)。
                下面我将说明Prompt的生出步骤,这里的Prompt主要用于描述人物。在Prompt的生成中,你需要通过提示词来描述
                人物属性,主题,外表,情绪,衣服,姿势,视角,动作,背景。
                用英语单词或短语甚至自然语言的标签来描述,并不局限于我给你的单词。然后将你想要的相似的提示词组合在一起,请使用英文半角,做分隔符,每个提示词不要带引号,并将这些按从最重要到最不重要的顺序
                排列。
                另外请您注意,永远在每个 Prompt的前面加上引号里的内容,
                “(((best quality))),(((ultra detailed))),(((masterpiece))),illustration,”
                这是高质量的标志。
                人物属性中,1girl表示你生成了一个女孩,2girls表示生成了两个女孩,一次。另外再注意,Prompt中不能带有-和_。
                可以有空格和自然语言,但不要太多,单词不能重复。只返回Prompt。
                '''
            },
            {
                'role': 'user',
                'content': prompt
            }
        ],
        temperature=0.5,
        max_tokens=2048,
        top_p=1,
        stream=False,
    )
    if response:
        glm = []
        glm_return_list = response.choices
        for item in glm_return_list:
            glm.append(item.message.content)
        return {'flag': 1, 'data': glm}
    else:
        return {'flag': 0, 'data': {}}


def process(
    num_samples, a_prompt, n_prompt,
    conditioning_scale, cfg_scale, inference_steps,
    seed, usr_prompt, rect_img,
    base_model, scheduler_name, box_num,
    *box_list
):
    """End-to-end generation: validate boxes, render the text hint, run SD.

    Returns (images, markdown) on success or (None, gr.Markdown('error'))
    on any validation/translation failure. Reuses the cached pipeline when
    the base model has not changed.
    """
    if usr_prompt == '':
        gr.Warning('Must input a prompt!')
        return None, gr.Markdown('error')
    if seed == -1:
        seed = random.randint(0, 2147483647)
    seed_everything(seed)
    # Support Chinese Input: use the canned translation when available,
    # otherwise detect any CJK character and translate via the LLM.
    if usr_prompt in chn_example_dict:
        usr_prompt = chn_example_dict[usr_prompt]
    else:
        for ch in usr_prompt:
            if '\u4e00' <= ch <= '\u9fff':
                data = preprocess_prompt(usr_prompt)
                if data['flag'] == 1:
                    # Strip the surrounding quote characters from the reply.
                    usr_prompt = data['data'][0][1: -1]
                else:
                    gr.Warning('Something went wrong while translating your prompt, please try again.')
                    return None, gr.Markdown('error')
                break
    shape = (rect_img.shape[1], rect_img.shape[0])
    render_list, flag = parse_render_list(box_list, shape, box_num)
    if flag:
        render_img = render_all_text(render_list, shape)
    else:
        return None, gr.Markdown('error')
    model_name = match_dict[base_model]
    render_img = canny(np.array(render_img))
    w, h = render_img.size
    global pipeline, pre_pipeline
    if pre_pipeline != model_name or pipeline is None:
        pre_pipeline = model_name
        pipeline = load_pipeline(model_name, scheduler_name)
    batch_render_img = [render_img for _ in range(num_samples)]
    batch_prompt = [f'{usr_prompt}, {a_prompt}' for _ in range(num_samples)]
    batch_n_prompt = [n_prompt for _ in range(num_samples)]
    images = pipeline(
        batch_prompt,
        negative_prompt=batch_n_prompt,
        image=batch_render_img,
        controlnet_conditioning_scale=float(conditioning_scale),
        guidance_scale=float(cfg_scale),
        width=w,
        height=h,
        num_inference_steps=int(inference_steps),
    ).images
    return images, gr.Markdown(f'{seed}, {usr_prompt}, {box_list}')


def draw_example(box_list, color, id):
    """Save the current box layout as a template JSON and a preview PNG.

    `box_list` is the flattened widget state (BBOX_MAX_NUM flags, then
    positions, fonts, texts); `color` holds (outline, fill) per box.
    Writes templates/<id>.json and ./examples/<id>.png.
    """
    board = Image.fromarray(create_canvas())
    w, h = board.size
    draw = ImageDraw.Draw(board, mode='RGBA')
    visible = box_list[:BBOX_MAX_NUM]
    pos = box_list[BBOX_MAX_NUM: 5 * BBOX_MAX_NUM]
    font = box_list[5 * BBOX_MAX_NUM: 6 * BBOX_MAX_NUM]
    text = box_list[6 * BBOX_MAX_NUM:]
    info = {
        'visible': list(visible),
        'pos': list(pos),
        'font': list(font),
        'text': list(text)
    }
    with open(f'templates/{id}.json', 'w') as f:
        json.dump(info, f)
    for i in range(BBOX_MAX_NUM):
        if visible[i] is True:
            polygon = pos[i * 4: (i + 1) * 4]
            print(polygon)
            left = w * polygon[0]
            top = h * polygon[1]
            right = left + w * polygon[2]
            bottom = top + h * polygon[3]
            draw.rectangle([left, top, right, bottom],
                           outline=color[i][0], fill=color[i][1], width=3)
    board.save(f'./examples/{id}.png')


if __name__ == '__main__':
    pass