import gradio as gr
import numpy as np
import random
import torch

import io, json
from PIL import Image
import os.path

from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text

device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = "GPU" if torch.cuda.is_available() else "CPU"

MAX_SEED = 100_000
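
# The demo below merges the ED-LoRA weights of two selected characters into a single model,
# builds a regionally rewritten prompt from the per-character prompts and the chosen pose
# layout, and then samples an image with keypose/sketch adapter conditioning.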


def generate(region1_concept,
             region2_concept,
             prompt,
             pose_image_name,
             region1_prompt,
             region2_prompt,
             negative_prompt,
             region_neg_prompt,
             seed,
             randomize_seed,
             sketch_adaptor_weight,
             keypose_adaptor_weight
             ):
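    """Validate the UI inputs, merge the two selected concepts, and run regionally
    controllable sampling with the chosen pose condition."""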
    if region1_concept == region2_concept:
        raise gr.Error("Please choose two different characters for merging weights.")
    if len(pose_image_name) == 0:
        raise gr.Error("Please select one spatial condition!")
    if len(region1_prompt) == 0 or len(region2_prompt) == 0:
        raise gr.Error("Your regional prompt cannot be empty.")
    if len(prompt) == 0:
        raise gr.Error("Your global prompt cannot be empty.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
    pretrained_model = merge(region1_concept, region2_concept)

    with open('multi-concept/pose_data/pose.json') as f:
        d = json.load(f)

    pose_image = {os.path.basename(obj['img_dir']): obj for obj in d}[pose_image_name]

    print(pose_image)
    keypose_condition = pose_image['img_dir']
    region1 = pose_image['region1']
    region2 = pose_image['region2']
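
    # prompt_rewrite packs one "<prompt>-*-<negative prompt>-*-<region>" triple per region,
    # with the two regions separated by "|"; prepare_text() parses this string during sampling.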
    region_pos_prompt = "high resolution, best quality, highly detailed, sharp focus, expressive, 8k uhd, detailed, sophisticated"
    region1_prompt = f'<{region1_concept}1> <{region1_concept}2>, {region1_prompt}, {region_pos_prompt}'
    region2_prompt = f'<{region2_concept}1> <{region2_concept}2>, {region2_prompt}, {region_pos_prompt}'
    prompt_rewrite = f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
    print(prompt_rewrite)

    result = infer(pretrained_model,
                   prompt,
                   prompt_rewrite,
                   negative_prompt,
                   seed,
                   keypose_condition,
                   keypose_adaptor_weight,
                   sketch_adaptor_weight=sketch_adaptor_weight,
                   )

    return result


def merge(concept1, concept2):
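    """Fuse the ED-LoRA weights of the two concepts into one combined base model with
    compose_concepts() and return the path of the merged checkpoint directory."""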
    device = "cuda" if torch.cuda.is_available() else "cpu"
    c1, c2 = sorted([concept1, concept2])
    assert c1 != c2
    merge_name = c1 + '_' + c2

    save_path = f'experiments/multi-concept/{merge_name}'

    if os.path.isdir(save_path):
        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
    else:
        os.makedirs(save_path)
        json_path = os.path.join(save_path, 'merge_config.json')
        alpha = 1.8
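        # One entry per concept: each points at a single-concept ED-LoRA checkpoint and sets
        # how strongly its weights are folded into the UNet and text encoder (the alphas).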
        data = [
            {
                "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c1}1> <{c1}2>"
            },
            {
                "lora_path": f"experiments/single-concept/{c2}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c2}1> <{c2}2>"
            }
        ]
        with io.open(json_path, 'w', encoding='utf8') as outfile:
            json.dump(data, outfile, indent=4, ensure_ascii=False)

        compose_concepts(
            concept_cfg=json_path,
            optimize_textenc_iters=500,
            optimize_unet_iters=50,
            pretrained_model_path="nitrosocke/mo-di-diffusion",
            save_path=save_path,
            suffix='base',
            device=device,
        )
        print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')

    modelbase_path = os.path.join(save_path, 'combined_model_base')
    assert os.path.isdir(modelbase_path)

    return modelbase_path


def infer(pretrained_model,
          prompt,
          prompt_rewrite,
          negative_prompt='',
          seed=16141,
          keypose_condition=None,
          keypose_adaptor_weight=1.0,
          sketch_condition=None,
          sketch_adaptor_weight=0.0,
          region_sketch_adaptor_weight='',
          region_keypose_adaptor_weight=''
          ):
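    """Build the regionally controllable sampling pipeline and generate one image under
    the given keypose/sketch conditions."""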
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    pipe = build_model(pretrained_model, device)

    if sketch_condition is not None and os.path.exists(sketch_condition):
        sketch_condition = Image.open(sketch_condition).convert('L')
        width_sketch, height_sketch = sketch_condition.size
        print('use sketch condition')
    else:
        sketch_condition, width_sketch, height_sketch = None, 0, 0
        print('skip sketch condition')

    if keypose_condition is not None and os.path.exists(keypose_condition):
        keypose_condition = Image.open(keypose_condition).convert('RGB')
        width_pose, height_pose = keypose_condition.size
        print('use pose condition')
    else:
        keypose_condition, width_pose, height_pose = None, 0, 0
        print('skip pose condition')

    if width_sketch != 0 and width_pose != 0:
        assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size'
    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)
    kwargs = {
        'sketch_condition': sketch_condition,
        'keypose_condition': keypose_condition,
        'height': height,
        'width': width,
    }
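
    # prepare_text() combines the global prompt with the "-*-"/"|" separated rewrite string
    # into the per-region prompt structure that sample_image() consumes.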
    prompts = [prompt]
    prompts_rewrite = [prompt_rewrite]
    input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
    save_prompt = input_prompt[0][0]
    print(save_prompt)

    image = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt] * len(input_prompt),
        generator=torch.Generator(device).manual_seed(seed),
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        **kwargs)

    return image[0]


def on_select(evt: gr.SelectData):
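    """Return the original filename of the clicked gallery image; generate() uses it as
    the key into the pose metadata."""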
    return evt.value['image']['orig_name']


examples_context = [
    'walking at Stanford university campus',
    'in front of a castle',
    'in the forest',
    'in the style of cyberpunk'
]

examples_region1 = ['wearing a red hat']
examples_region2 = ['smiling, wearing a blue shirt']
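
# Pose metadata: each entry carries an image path plus the two regions used for regionally
# controllable sampling; the (img_id, img_dir) pairs populate the pose gallery below.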
with open('multi-concept/pose_data/pose.json') as f:
    d = json.load(f)
pose_image_list = [(obj['img_id'], obj['img_dir']) for obj in d]

css = """
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(f"""
    # Orthogonal Adaptation
    Describe your world with a **text prompt (global and local)** and choose two characters to merge.
    Select their **poses (spatial conditions)** for regionally controllable sampling to generate a unique image using our model.
    Let your creativity run wild! (Currently running on: {power_device})
    """)

    with gr.Row():
        with gr.Column(elem_id="col-container"):

            with gr.Tab('Global and Region prompts'):
                prompt = gr.Text(
                    label="Context Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your global context prompt",
                    container=False,
                )

                with gr.Row():
                    concept_list = ["Elsa", "Moana", "Woody", "Rapunzel", "Elastigirl"]
                    region1_concept = gr.Dropdown(
                        concept_list,
                        label="Character 1",
                        info="Will add more characters later!"
                    )
                    region2_concept = gr.Dropdown(
                        concept_list,
                        label="Character 2",
                        info="Will add more characters later!"
                    )

                with gr.Row():
                    region1_prompt = gr.Textbox(
                        label="Region1 Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your regional prompt for character 1",
                        container=False,
                    )
                    region2_prompt = gr.Textbox(
                        label="Region2 Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your regional prompt for character 2",
                        container=False,
                    )

                gr.Examples(
                    label='Global Prompt example',
                    examples=examples_context,
                    inputs=[prompt]
                )

                with gr.Row():
                    gr.Examples(
                        label='Region1 Prompt example',
                        examples=examples_region1,
                        inputs=[region1_prompt]
                    )
                    gr.Examples(
                        label='Region2 Prompt example',
                        examples=examples_region2,
                        inputs=[region2_prompt]
                    )

            with gr.Tab('Spatial Condition'):
                gallery = gr.Gallery(label="Select pose for characters",
                                     value=[obj[1] for obj in pose_image_list],
                                     elem_id=[obj[0] for obj in pose_image_list],
                                     interactive=False, show_download_button=False,
                                     preview=True, height=400, object_fit="scale-down")

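                # Hidden textbox holding the filename of the selected pose image;
                # on_select() fills it whenever the gallery selection changes.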
                pose_image_name = gr.Textbox(visible=False)
                gallery.select(on_select, None, pose_image_name)

            run_button = gr.Button("Run", scale=1)

            with gr.Accordion("Advanced Settings", open=False):

                negative_prompt = gr.Text(
                    label="Context Negative prompt",
                    max_lines=1,
                    value='saturated, cropped, worst quality, low quality',
                    visible=False,
                )

                region_neg_prompt = gr.Text(
                    label="Regional Negative prompt",
                    max_lines=1,
                    value='shirtless, nudity, saturated, cropped, worst quality, low quality',
                    visible=False,
                )

                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )

                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                with gr.Row():
                    sketch_adaptor_weight = gr.Slider(
                        label="Sketch Adapter Weight",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0,
                    )

                    keypose_adaptor_weight = gr.Slider(
                        label="Keypose Adapter Weight",
                        minimum=0.1,
                        maximum=1,
                        step=0.01,
                        value=1.0,
                    )

        with gr.Column():
            result = gr.Image(label="Result", show_label=False)

            gr.Markdown("""
            *Image generation may take longer the first time you use a new combination of characters. <br />
            This is because the weights of the two selected concepts have to be merged and loaded first.*
            """)

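    # Note: the order of `inputs` must match the parameter order of generate().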
    run_button.click(
        fn=generate,
        inputs=[region1_concept,
                region2_concept,
                prompt,
                pose_image_name,
                region1_prompt,
                region2_prompt,
                negative_prompt,
                region_neg_prompt,
                seed,
                randomize_seed,
                sketch_adaptor_weight,
                keypose_adaptor_weight
                ],
        outputs=[result]
    )

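# queue() enables request queuing for the demo; share=True also creates a temporary public Gradio link.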
demo.queue().launch(share=True)