import gradio as gr
import numpy as np
import random
import torch
import io, json
from PIL import Image
import os.path

from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text

device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = "GPU" if torch.cuda.is_available() else "CPU"
MAX_SEED = 100_000


def generate(region1_concept,
             region2_concept,
             prompt,
             pose_image_name,
             region1_prompt,
             region2_prompt,
             negative_prompt,
             region_neg_prompt,
             seed,
             randomize_seed,
             sketch_adaptor_weight,
             keypose_adaptor_weight
             ):
    if region1_concept == region2_concept:
        raise gr.Error("Please choose two different characters for merging weights.")
    if len(pose_image_name) == 0:
        raise gr.Error("Please select one spatial condition!")
    if len(region1_prompt) == 0 or len(region2_prompt) == 0:
        raise gr.Error("Your regional prompt cannot be empty.")
    if len(prompt) == 0:
        raise gr.Error("Your global prompt cannot be empty.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
    pretrained_model = merge(region1_concept, region2_concept)

    # Look up the selected spatial condition by the file name of the chosen gallery image.
    with open('multi-concept/pose_data/pose.json') as f:
        d = json.load(f)
    pose_image = {os.path.basename(obj['img_dir']): obj for obj in d}[pose_image_name]
    # pose_image = {obj.pop('pose_id'): obj for obj in d}[int(pose_image_id)]
    print(pose_image)
    keypose_condition = pose_image['img_dir']
    region1 = pose_image['region1']
    region2 = pose_image['region2']

    region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
    region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
    prompt_rewrite = f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
    print(prompt_rewrite)

    result = infer(pretrained_model,
                   prompt,
                   prompt_rewrite,
                   negative_prompt,
                   seed,
                   keypose_condition,
                   keypose_adaptor_weight,
                   # sketch_condition,
                   # sketch_adaptor_weight,
                   )
    return result
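
# A minimal, hypothetical helper (not called by the app): it only illustrates the
# `prompt_rewrite` string that `generate` assembles above and that `prepare_text`
# parses. Each region reads "[<token1> <token2>, details]-*-negative prompt-*-region",
# and the two regions are joined with "|". The concept tokens below are illustrative,
# and the region descriptors passed in are placeholders; in the app the region values
# are taken verbatim from the selected pose.json entry.
def _example_prompt_rewrite(region1, region2, neg='worst quality, low quality'):
    region1_prompt = '[<elsa1> <elsa2>, wearing red hat]'   # illustrative tokens only
    region2_prompt = '[<moana1> <moana2>, smiling]'         # illustrative tokens only
    return f'{region1_prompt}-*-{neg}-*-{region1}|{region2_prompt}-*-{neg}-*-{region2}'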

def merge(concept1, concept2):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    c1, c2 = sorted([concept1, concept2])
    assert c1 != c2
    merge_name = c1 + '_' + c2
    save_path = f'experiments/multi-concept/{merge_name}'

    if os.path.isdir(save_path):
        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
    else:
        os.makedirs(save_path)
        json_path = os.path.join(save_path, 'merge_config.json')
        alpha = 1.8
        data = [
            {
                "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c1}1> <{c1}2>"
            },
            {
                "lora_path": f"experiments/single-concept/{c2}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c2}1> <{c2}2>"
            }
        ]
        with io.open(json_path, 'w', encoding='utf8') as outfile:
            json.dump(data, outfile, indent=4, ensure_ascii=False)

        compose_concepts(
            concept_cfg=json_path,
            optimize_textenc_iters=500,
            optimize_unet_iters=50,
            pretrained_model_path="nitrosocke/mo-di-diffusion",
            save_path=save_path,
            suffix='base',
            device=device,
        )
        print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')

    modelbase_path = os.path.join(save_path, 'combined_model_base')
    assert os.path.isdir(modelbase_path)
    return modelbase_path


def infer(pretrained_model,
          prompt,
          prompt_rewrite,
          negative_prompt='',
          seed=16141,
          keypose_condition=None,
          keypose_adaptor_weight=1.0,
          sketch_condition=None,
          sketch_adaptor_weight=0.0,
          region_sketch_adaptor_weight='',
          region_keypose_adaptor_weight=''
          ):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    pipe = build_model(pretrained_model, device)

    if sketch_condition is not None and os.path.exists(sketch_condition):
        sketch_condition = Image.open(sketch_condition).convert('L')
        width_sketch, height_sketch = sketch_condition.size
        print('use sketch condition')
    else:
        sketch_condition, width_sketch, height_sketch = None, 0, 0
        print('skip sketch condition')

    if keypose_condition is not None and os.path.exists(keypose_condition):
        keypose_condition = Image.open(keypose_condition).convert('RGB')
        width_pose, height_pose = keypose_condition.size
        print('use pose condition')
    else:
        keypose_condition, width_pose, height_pose = None, 0, 0
        print('skip pose condition')

    if width_sketch != 0 and width_pose != 0:
        assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size'
    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)

    kwargs = {
        'sketch_condition': sketch_condition,
        'keypose_condition': keypose_condition,
        'height': height,
        'width': width,
    }

    prompts = [prompt]
    prompts_rewrite = [prompt_rewrite]
    input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
    save_prompt = input_prompt[0][0]
    print(save_prompt)

    image = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt] * len(input_prompt),
        generator=torch.Generator(device).manual_seed(seed),
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        **kwargs)

    return image[0]


def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    return evt.value['image']['orig_name']


examples_context = [
    'walking at Stanford university campus',
    'in a castle',
    'in the forest',
    'in front of Eiffel tower',
]
examples_region1 = ['wearing red hat, high resolution, best quality']
examples_region2 = ['smiling, wearing blue shirt, high resolution, best quality']

with open('multi-concept/pose_data/pose.json') as f:
    d = json.load(f)
pose_image_list = [(obj['img_id'], obj['img_dir']) for obj in d]

css = """
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""
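
# Assumed shape of a multi-concept/pose_data/pose.json entry, inferred from the keys
# read above and in `generate` (img_id, img_dir, region1, region2). The values shown
# are placeholders, not real data; region1/region2 are spliced verbatim into
# prompt_rewrite, so they are stored in whatever format prepare_text expects:
#
#   {
#       "img_id": 0,
#       "img_dir": "multi-concept/pose_data/<pose image>.png",
#       "region1": "...",
#       "region2": "..."
#   }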

with gr.Blocks(css=css) as demo:

    gr.Markdown(f"""
    # Orthogonal Adaptation

    Describe your world with a **🪄 text prompt (global and local)** and choose two characters to merge.
    Select their **👯 poses (spatial conditions)** for regionally controllable sampling to generate a unique image with our model.
    Let your creativity run wild! (Currently running on: {power_device})
    """)

    with gr.Row():

        with gr.Column(elem_id="col-container"):

            # gr.Markdown(f"""
            # ### 🪄 Global and Region prompts
            # """)
            # with gr.Group():
            with gr.Tab('🪄 Global and Region prompts'):
                prompt = gr.Text(
                    label="Context Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your global context prompt",
                    container=False,
                )
                with gr.Row():
                    region1_concept = gr.Dropdown(
                        ["Elsa", "Moana", "Woody"],
                        label="Character 1",
                        info="Will add more characters later!",
                    )
                    region2_concept = gr.Dropdown(
                        ["Elsa", "Moana", "Woody"],
                        label="Character 2",
                        info="Will add more characters later!",
                    )
                with gr.Row():
                    region1_prompt = gr.Textbox(
                        label="Region1 Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your regional prompt for character 1",
                        container=False,
                    )
                    region2_prompt = gr.Textbox(
                        label="Region2 Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your regional prompt for character 2",
                        container=False,
                    )
                gr.Examples(
                    label='Global Prompt example',
                    examples=examples_context,
                    inputs=[prompt],
                )
                with gr.Row():
                    gr.Examples(
                        label='Region1 Prompt example',
                        examples=examples_region1,
                        inputs=[region1_prompt],
                    )
                    gr.Examples(
                        label='Region2 Prompt example',
                        examples=examples_region2,
                        inputs=[region2_prompt],
                    )

            # gr.Markdown(f"""
            # ### 👯 Spatial Condition
            # """)
            # with gr.Group():
            with gr.Tab('👯 Spatial Condition'):
                gallery = gr.Gallery(
                    label="Select pose for characters",
                    value=[obj[1] for obj in pose_image_list],
                    elem_id=[obj[0] for obj in pose_image_list],
                    interactive=False,
                    show_download_button=False,
                    preview=True,
                    height=400,
                    object_fit="scale-down",
                )
                pose_image_name = gr.Textbox(visible=False)
                gallery.select(on_select, None, pose_image_name)

            run_button = gr.Button("Run", scale=1)

            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Text(
                    label="Context Negative prompt",
                    max_lines=1,
                    value='saturated, cropped, worst quality, low quality',
                    visible=False,
                )
                region_neg_prompt = gr.Text(
                    label="Regional Negative prompt",
                    max_lines=1,
                    value='shirtless, nudity, saturated, cropped, worst quality, low quality',
                    visible=False,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    sketch_adaptor_weight = gr.Slider(
                        label="Sketch Adapter Weight",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0,
                    )
                    keypose_adaptor_weight = gr.Slider(
                        label="Keypose Adapter Weight",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=1.0,
                    )

        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            gr.Markdown("""
            *Image generation may take longer the first time you use a new combination of characters.
            This is because the model first needs to merge and load the weights for each concept involved.*
            """)
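    # Selecting a gallery thumbnail stores the image's original file name in the hidden
    # `pose_image_name` textbox (via `on_select`); when Run is clicked, `generate` uses
    # that name to look up the matching entry in pose.json.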
    run_button.click(
        fn=generate,
        inputs=[
            region1_concept,
            region2_concept,
            prompt,
            pose_image_name,
            region1_prompt,
            region2_prompt,
            negative_prompt,
            region_neg_prompt,
            seed,
            randomize_seed,
            # sketch_condition,
            # keypose_condition,
            sketch_adaptor_weight,
            keypose_adaptor_weight,
        ],
        outputs=[result],
    )

demo.queue().launch(share=True)
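
# A hedged usage sketch (kept commented out so the app's behavior is unchanged): the
# same pipeline can in principle be driven without the Gradio UI, assuming the
# single-concept EDLoRA checkpoints exist under experiments/single-concept/ and a pose
# image is available. The pose path and region strings below are placeholders.
#
#   model_path = merge('elsa', 'moana')
#   image = infer(
#       pretrained_model=model_path,
#       prompt='walking at Stanford university campus',
#       prompt_rewrite=_example_prompt_rewrite(region1='...', region2='...'),
#       negative_prompt='saturated, cropped, worst quality, low quality',
#       seed=16141,
#       keypose_condition='multi-concept/pose_data/<pose image>.png',
#       keypose_adaptor_weight=1.0,
#   )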