import json
import os
import random

import gradio as gr
import torch
from PIL import Image

from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text

device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = "GPU" if torch.cuda.is_available() else "CPU"

MAX_SEED = 100_000


def generate(region1_concept,
             region2_concept,
             prompt,
             region1_prompt,
             region2_prompt,
             negative_prompt,
             region_neg_prompt,
             seed,
             randomize_seed,
             sketch_adaptor_weight,
             keypose_adaptor_weight):
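    """Gradio callback: merge the two selected concepts, then run regionally
    controlled sampling with a fixed two-person pose template."""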

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
    pretrained_model = merge(region1_concept, region2_concept)

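    # Fixed two-person pose template shipped with the repo; region1/region2 are
    # the pixel boxes assigned to character 1 and character 2 within it.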
    keypose_condition = 'multi-concept/pose_data/two_apart.png'
    region1 = '[0, 0, 512, 290]'
    region2 = '[0, 650, 512, 910]'

    region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
    region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
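    # Each region is encoded as "prompt-*-negative-*-box"; regions are joined with "|".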
    prompt_rewrite = f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"

    result = infer(pretrained_model,
                   prompt,
                   prompt_rewrite,
                   negative_prompt,
                   seed,
                   keypose_condition,
                   keypose_adaptor_weight,
                   sketch_adaptor_weight=sketch_adaptor_weight)

    return result


def merge(concept1, concept2):
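    """Fuse the two single-concept LoRAs into one base model, caching the result on disk."""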
    device = "cuda" if torch.cuda.is_available() else "cpu"
    c1, c2 = sorted([concept1, concept2])
    assert c1 != c2, 'please select two different concepts'
    merge_name = c1 + '_' + c2

    save_path = f'experiments/multi-concept/{merge_name}'

    if os.path.isdir(save_path):
        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
    else:
        os.makedirs(save_path)
        json_path = os.path.join(save_path, 'merge_config.json')
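        # One strength shared by the U-Net and text-encoder LoRA weights of each concept.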
        alpha = 1.8
        data = [
            {
                "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c1}1> <{c1}2>"
            },
            {
                "lora_path": f"experiments/single-concept/{c2}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c2}1> <{c2}2>"
            }
        ]
        with open(json_path, 'w', encoding='utf8') as outfile:
            json.dump(data, outfile, indent=4, ensure_ascii=False)
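
        # Fuse the two concept LoRAs into a single checkpoint. This is the slow
        # step; it runs once per pair because the merged weights are cached above.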
        compose_concepts(
            concept_cfg=json_path,
            optimize_textenc_iters=500,
            optimize_unet_iters=50,
            pretrained_model_path="nitrosocke/mo-di-diffusion",
            save_path=save_path,
            suffix='base',
            device=device,
        )
        print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')

    modelbase_path = os.path.join(save_path, 'combined_model_base')
    assert os.path.isdir(modelbase_path), f'merged model not found at {modelbase_path}'

    return modelbase_path


def infer(pretrained_model,
          prompt,
          prompt_rewrite,
          negative_prompt='',
          seed=16141,
          keypose_condition=None,
          keypose_adaptor_weight=1.0,
          sketch_condition=None,
          sketch_adaptor_weight=0.0,
          region_sketch_adaptor_weight='',
          region_keypose_adaptor_weight=''):
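    """Build the regionally controllable pipeline for the merged model and sample one image."""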
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    pipe = build_model(pretrained_model, device)

    if sketch_condition is not None and os.path.exists(sketch_condition):
        sketch_condition = Image.open(sketch_condition).convert('L')
        width_sketch, height_sketch = sketch_condition.size
        print('use sketch condition')
    else:
        sketch_condition, width_sketch, height_sketch = None, 0, 0
        print('skip sketch condition')

    if keypose_condition is not None and os.path.exists(keypose_condition):
        keypose_condition = Image.open(keypose_condition).convert('RGB')
        width_pose, height_pose = keypose_condition.size
        print('use pose condition')
    else:
        keypose_condition, width_pose, height_pose = None, 0, 0
        print('skip pose condition')

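    # When both conditions are given they must share one size; the output
    # resolution follows whichever condition is available.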
    if width_sketch != 0 and width_pose != 0:
        assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size'
    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)
    kwargs = {
        'sketch_condition': sketch_condition,
        'keypose_condition': keypose_condition,
        'height': height,
        'width': width,
    }

    prompts = [prompt]
    prompts_rewrite = [prompt_rewrite]
    input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
    save_prompt = input_prompt[0][0]
    print(save_prompt)

    image = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt] * len(input_prompt),
        # Gradio sliders deliver floats, so cast before seeding the generator.
        generator=torch.Generator(device).manual_seed(int(seed)),
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        **kwargs)

    return image[0]


examples_context = [
    'walking at Stanford university campus',
    'in a castle',
    'in the forest',
    'in front of Eiffel tower'
]

examples_region1 = ['wearing red hat, high resolution, best quality',
                    'bright smile, wearing pants, best quality']
examples_region2 = ['smiling, wearing blue shirt, high resolution, best quality']

css = """
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Orthogonal Adaptation
        Currently running on {power_device}.
        """)
        prompt = gr.Text(
            label="Context Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter your context prompt for the overall image",
            container=False,
        )

        with gr.Row():
            region1_concept = gr.Dropdown(
                ["Elsa", "Moana"],
                label="Character 1",
                info="Will add more characters later!"
            )
            region2_concept = gr.Dropdown(
                ["Elsa", "Moana"],
                label="Character 2",
                info="Will add more characters later!"
            )

        with gr.Row():
            region1_prompt = gr.Textbox(
                label="Region1 Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt for character 1",
                container=False,
            )
            region2_prompt = gr.Textbox(
                label="Region2 Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt for character 2",
                container=False,
            )

        run_button = gr.Button("Run", scale=1)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):

            negative_prompt = gr.Text(
                label="Context Negative Prompt",
                max_lines=1,
                value='saturated, cropped, worst quality, low quality',
                visible=False,
            )

            region_neg_prompt = gr.Text(
                label="Regional Negative Prompt",
                max_lines=1,
                value='shirtless, nudity, saturated, cropped, worst quality, low quality',
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                sketch_adaptor_weight = gr.Slider(
                    label="Sketch Adapter Weight",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0,
                )

                keypose_adaptor_weight = gr.Slider(
                    label="Keypose Adapter Weight",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=1.0,
                )

        gr.Examples(
            label='Context Prompt examples',
            examples=examples_context,
            inputs=[prompt]
        )

        with gr.Row():
            gr.Examples(
                label='Region1 Prompt examples',
                examples=examples_region1,
                inputs=[region1_prompt]
            )

            gr.Examples(
                label='Region2 Prompt examples',
                examples=examples_region2,
                inputs=[region2_prompt]
            )
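
    # The inputs list must match the parameter order of generate().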
    run_button.click(
        fn=generate,
        inputs=[region1_concept,
                region2_concept,
                prompt,
                region1_prompt,
                region2_prompt,
                negative_prompt,
                region_neg_prompt,
                seed,
                randomize_seed,
                sketch_adaptor_weight,
                keypose_adaptor_weight],
        outputs=[result]
    )
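
# queue() processes requests one at a time, which protects the single GPU;
# share=True additionally exposes a temporary public link.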
demo.queue().launch(share=True)