ortha / app.py
ujin-song's picture
Update app.py -- first released version
57d7bf6 verified
raw
history blame
10.4 kB
import gradio as gr
import numpy as np
import random
import torch
import io, json
from PIL import Image
import os.path
from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text
# Pick the compute backend once at import time. `power_device` is the
# human-readable label interpolated into the page header of the demo.
device, power_device = ("cuda", "GPU") if torch.cuda.is_available() else ("cpu", "CPU")

# Inclusive upper bound shared by the random seed draw and the seed slider.
MAX_SEED = 100_000
def generate(region1_concept,
             region2_concept,
             prompt,
             region1_prompt,
             region2_prompt,
             negative_prompt,
             region_neg_prompt,
             seed,
             randomize_seed,
             sketch_adaptor_weight,
             keypose_adaptor_weight
             ):
    """Gradio click handler: build the regional prompts and sample one image.

    Args:
        region1_concept: Character name chosen for the first region.
        region2_concept: Character name chosen for the second region.
        prompt: Context prompt describing the overall image.
        region1_prompt: Extra description for the first character.
        region2_prompt: Extra description for the second character.
        negative_prompt: Negative prompt for the overall image.
        region_neg_prompt: Negative prompt applied to both regions.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        sketch_adaptor_weight: Accepted because the UI wires it in, but not
            forwarded: no sketch condition is supplied by this handler.
        keypose_adaptor_weight: Weight forwarded to ``infer`` for the pose
            condition.

    Returns:
        Whatever ``infer`` returns for the single generated sample.

    Raises:
        ValueError: If a character is missing or both characters are the same.
    """
    # Fail fast with a readable message instead of an AttributeError on
    # ``None.lower()`` or an AssertionError deep inside ``merge``.
    if not region1_concept or not region2_concept:
        raise ValueError('Please select a character for both regions.')
    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
    if region1_concept == region2_concept:
        raise ValueError('Please select two different characters.')

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may hand over a float; the seed ends up in
    # ``torch.Generator.manual_seed`` (inside ``infer``), which expects an int.
    seed = int(seed)

    pretrained_model = merge(region1_concept, region2_concept)

    # Fixed pose image and region boxes used for every request.
    # NOTE(review): the coordinate order of the boxes is defined by
    # ``prepare_text`` in the sampling module -- confirm there before editing.
    keypose_condition = 'multi-concept/pose_data/two_apart.png'
    region1 = '[0, 0, 512, 290]'
    region2 = '[0, 650, 512, 910]'

    # Prefix each regional prompt with the two concept tokens of its character.
    region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
    region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
    # One "prompt-*-negative-*-box" triple per region, joined by "|".
    prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"

    return infer(
        pretrained_model,
        prompt,
        prompt_rewrite,
        negative_prompt,
        seed,
        keypose_condition,
        keypose_adaptor_weight,
    )
def merge(concept1, concept2):
    """Return the directory of the merged model for two concepts, fusing on demand.

    The two names are sorted so that ``merge(a, b)`` and ``merge(b, a)`` share
    one output directory, ``experiments/multi-concept/<c1>_<c2>``. If the
    merged model directory already exists it is reused; otherwise a merge
    config is written and ``compose_concepts`` is run.

    Args:
        concept1: Lower-case name of the first concept.
        concept2: Lower-case name of the second concept.

    Returns:
        Path of the ``combined_model_base`` directory inside the merge folder.

    Raises:
        ValueError: If both concepts are the same.
        FileNotFoundError: If fusion finished without producing the expected
            ``combined_model_base`` directory.
    """
    c1, c2 = sorted([concept1, concept2])
    if c1 == c2:
        raise ValueError(f'Cannot merge concept {c1!r} with itself.')

    merge_name = c1+'_'+c2
    save_path = f'experiments/multi-concept/{merge_name}'
    modelbase_path = os.path.join(save_path, 'combined_model_base')

    # Key the cache on the merged model itself, not on its parent folder: a
    # fusion that failed part-way leaves ``save_path`` behind, and treating
    # that as a hit would make every later call fail.
    if os.path.isdir(modelbase_path):
        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
        return modelbase_path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(save_path, exist_ok=True)
    json_path = os.path.join(save_path, 'merge_config.json')

    # Both concepts use the same fusion weight for the UNet and text encoder.
    alpha = 1.8
    data = [
        {
            "lora_path": f"experiments/single-concept/{name}/models/edlora_model-latest.pth",
            "unet_alpha": alpha,
            "text_encoder_alpha": alpha,
            "concept_name": f"<{name}1> <{name}2>",
        }
        for name in (c1, c2)
    ]
    with open(json_path, 'w', encoding='utf8') as outfile:
        json.dump(data, outfile, indent=4, ensure_ascii=False)

    compose_concepts(
        concept_cfg=json_path,
        optimize_textenc_iters=500,
        optimize_unet_iters=50,
        pretrained_model_path="nitrosocke/mo-di-diffusion",
        save_path=save_path,
        suffix='base',
        device=device,
    )
    print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')

    # ``compose_concepts`` is expected to write ``combined_model_<suffix>``.
    if not os.path.isdir(modelbase_path):
        raise FileNotFoundError(
            f'Merged model directory {modelbase_path} was not created.')
    return modelbase_path
def infer(pretrained_model,
          prompt,
          prompt_rewrite,
          negative_prompt='',
          seed=16141,
          keypose_condition=None,
          keypose_adaptor_weight=1.0,
          sketch_condition=None,
          sketch_adaptor_weight=0.0,
          region_sketch_adaptor_weight='',
          region_keypose_adaptor_weight=''
          ):
    """Sample a single image from ``pretrained_model`` and return it.

    ``sketch_condition`` and ``keypose_condition`` are optional image paths.
    A path that is ``None`` or does not exist is skipped. The output size is
    taken from whichever condition images were loaded; when both are loaded
    they must have identical dimensions.

    Args:
        pretrained_model: Model location handed to ``build_model``.
        prompt: Context prompt for the whole image.
        prompt_rewrite: Regional prompt string handed to ``prepare_text``.
        negative_prompt: Negative prompt for the whole image.
        seed: Seed for the ``torch.Generator`` used while sampling.
        keypose_condition: Path to a pose image (loaded as RGB), or ``None``.
        keypose_adaptor_weight: Forwarded to ``sample_image``.
        sketch_condition: Path to a sketch image (loaded as 'L'), or ``None``.
        sketch_adaptor_weight: Forwarded to ``sample_image``.
        region_sketch_adaptor_weight: Forwarded to ``sample_image``.
        region_keypose_adaptor_weight: Forwarded to ``sample_image``.

    Returns:
        The first element of what ``sample_image`` returns.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pipe = build_model(pretrained_model, device)

    # Sketch condition: default to "absent" and overwrite on success.
    sketch_image, width_sketch, height_sketch = None, 0, 0
    if sketch_condition is not None and os.path.exists(sketch_condition):
        sketch_image = Image.open(sketch_condition).convert('L')
        width_sketch, height_sketch = sketch_image.size
        print('use sketch condition')
    else:
        print('skip sketch condition')

    # Pose condition: same pattern as above.
    pose_image, width_pose, height_pose = None, 0, 0
    if keypose_condition is not None and os.path.exists(keypose_condition):
        pose_image = Image.open(keypose_condition).convert('RGB')
        width_pose, height_pose = pose_image.size
        print('use pose condition')
    else:
        print('skip pose condition')

    # A zero width marks a missing condition; only compare when both exist.
    if width_sketch != 0 and width_pose != 0:
        assert (width_sketch, height_sketch) == (width_pose, height_pose), 'conditions should be same size'
    width = max(width_pose, width_sketch)
    height = max(height_pose, height_sketch)

    # Exactly one prompt / rewrite pair is sampled per call.
    input_prompt = [prepare_text(prompt, prompt_rewrite, height, width)]
    save_prompt = input_prompt[0][0]
    print(save_prompt)

    images = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt] * len(input_prompt),
        generator=torch.Generator(device).manual_seed(seed),
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        sketch_condition=sketch_image,
        keypose_condition=pose_image,
        height=height,
        width=width,
    )
    return images[0]
# Clickable example prompts shown under the inputs (consumed by gr.Examples).
examples_context = [
    'walking at Stanford university campus',
    'in a castle',
    'in the forest',
    'in front of Eiffel tower',
]
examples_region1 = [
    'wearing red hat, high resolution, best quality',
    'bright smile, wearing pants, best quality',
]
# Spelling fixed ("smilling" -> "smiling"): this text is used as a prompt.
examples_region2 = ['smiling, wearing blue shirt, high resolution, best quality']

# Page-level CSS passed to gr.Blocks; styles the "col-container" column.
css="""
#col-container {
margin: 0 auto;
max-width: 600px;
}
"""
# Gradio UI definition and launch.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file whose indentation had been lost -- confirm the
# intended placement of the button, accordion contents and example widgets.
with gr.Blocks(css=css) as demo:
    # Characters that can be placed in either region.
    character_choices = ["Elsa", "Moana"]

    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Orthogonal Adaptation
Currently running on {power_device}.
""")
        prompt = gr.Text(
            label="ContextPrompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter your context prompt for overall image",
            container=False,
        )
        with gr.Row():
            # Distinct defaults so that pressing "Run" straight away sends two
            # valid, different characters to `generate` instead of None.
            region1_concept = gr.Dropdown(
                character_choices,
                value="Elsa",
                label="Character 1",
                info="Will add more characters later!"
            )
            region2_concept = gr.Dropdown(
                character_choices,
                value="Moana",
                label="Character 2",
                info="Will add more characters later!"
            )
        with gr.Row():
            region1_prompt = gr.Textbox(
                label="Region1 Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt for character 1",
                container=False,
            )
            region2_prompt = gr.Textbox(
                label="Region2 Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt for character 2",
                container=False,
            )
        run_button = gr.Button("Run", scale=1)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            # Both negative prompts are hidden; their preset values are still
            # sent to `generate` on every run.
            negative_prompt = gr.Text(
                label="Context Negative prompt",
                max_lines=1,
                value = 'saturated, cropped, worst quality, low quality',
                visible=False,
            )
            region_neg_prompt = gr.Text(
                label="Regional Negative prompt",
                max_lines=1,
                value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
                visible=False,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # `generate` receives this value but does not forward it, so
                # the slider currently has no effect on the output.
                sketch_adaptor_weight = gr.Slider(
                    label="Sketch Adapter Weight",
                    minimum = 0,
                    maximum = 1,
                    step=0.01,
                    value=0,
                )
                keypose_adaptor_weight = gr.Slider(
                    label="Keypose Adapter Weight",
                    minimum = 0,
                    maximum = 1,
                    step= 0.01,
                    value=1.0,
                )
        gr.Examples(
            label = 'Context Prompt example',
            examples = examples_context,
            inputs = [prompt]
        )
        with gr.Row():
            gr.Examples(
                label = 'Region1 Prompt example',
                examples = examples_region1,
                inputs = [region1_prompt]
            )
            # Pass the flat list, as for region 1: wrapping it in another list
            # would turn a second example into a second *column* rather than
            # a second row.
            gr.Examples(
                label = 'Region2 Prompt example',
                examples = examples_region2,
                inputs = [region2_prompt]
            )
    # The input order must match the positional parameters of `generate`.
    run_button.click(
        fn = generate,
        inputs = [region1_concept,
                  region2_concept,
                  prompt,
                  region1_prompt,
                  region2_prompt,
                  negative_prompt,
                  region_neg_prompt,
                  seed,
                  randomize_seed,
                  sketch_adaptor_weight,
                  keypose_adaptor_weight
                  ],
        outputs = [result]
    )
demo.queue().launch(share=True)