#!/usr/bin/env python
"""Demo app for https://github.com/adobe-research/custom-diffusion.
The code in this repo is partly adapted from the following repository:
https://huggingface.co/spaces/hysts/LoRA-SD-training
MIT License
Copyright (c) 2022 hysts
==========================================================================================
Adobe’s modifications are Copyright 2022 Adobe Research. All rights reserved.
Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
LICENSE.
==========================================================================================
"""
from __future__ import annotations
import sys
import os
import pathlib
import gradio as gr
import torch
from inference import InferencePipeline
# from trainer import Trainer
# from uploader import upload
# ---------------------------------------------------------------------------
# Static UI text and environment-derived settings.
# ---------------------------------------------------------------------------
TITLE = '# Custom Diffusion + StableDiffusion Training UI'
DESCRIPTION = '''This is a demo for [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion).
It is recommended to upgrade to GPU in Settings after duplicating this space to use it.
'''
DETAILDESCRIPTION = '''
Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
We fine-tune only a subset of model parameters, namely key and value projection matrices, in the cross-attention layers and the modifier token used to represent the object.
This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There's still limitations on which compositions work. For more analysis please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
'''

# SPACE_ID falls back to the original Space when the variable is not set.
ORIGINAL_SPACE_ID = 'Ziqi/ReVersion'
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)

# Plain string: the text contains no placeholders, so no f-prefix is needed.
SHARED_UI_WARNING = '''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
'''

# NOTE(review): this used to be chosen by an `if SYSTEM == 'spaces' and
# SPACE_ID != ORIGINAL_SPACE_ID` check whose two branches assigned the very
# same literal, so the check had no effect and was dropped. If a
# Space-specific settings link was intended, reintroduce it here.
SETTINGS = 'Settings'

CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
'''

# ---------------------------------------------------------------------------
# Fetch the ReVersion sources and make them importable.
# ---------------------------------------------------------------------------
_REVERSION_DIR = 'ReVersion'
# Skip the clone when a checkout is already present: `git clone` into an
# existing non-empty directory only fails with an error on every restart.
if not os.path.isdir(_REVERSION_DIR):
    _clone_status = os.system("git clone https://github.com/ziqihuangg/ReVersion")
    if _clone_status != 0:
        # Best effort: report the failure but keep starting the app, as before.
        print(f'warning: cloning ReVersion failed (exit status {_clone_status})',
              file=sys.stderr)
sys.path.append("ReVersion")
def show_warning(warning_text: str) -> gr.Blocks:
    """Build a Blocks container holding a boxed Markdown warning.

    Args:
        warning_text: Markdown source rendered inside the box.

    Returns:
        The ``gr.Blocks`` instance that wraps the warning.
    """
    warning_blocks = gr.Blocks()
    # The Box is created after the Blocks context has been entered, so it is
    # nested inside it exactly as with two stacked `with` statements.
    with warning_blocks, gr.Box():
        gr.Markdown(warning_text)
    return warning_blocks
def update_output_files() -> dict:
    """Return a Gradio update listing the ``*.bin`` files in ``results``.

    The paths are sorted and converted to POSIX strings; when no file matches,
    the update carries ``None`` as its value.
    """
    weight_paths = [
        found.as_posix()
        for found in sorted(pathlib.Path('results').glob('*.bin'))
    ]
    if not weight_paths:
        return gr.update(value=None)
    return gr.update(value=weight_paths)
def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline) -> gr.Blocks:
    """Build the training UI and wire its buttons to ``trainer`` and ``pipe``.

    The layout holds the base-model and resolution dropdowns, up to
    ``maximum_concepts`` concept slots (image upload, concept prompt and class
    prompt, with add/delete buttons that show or hide the following slot), the
    training parameters, a "Start Training" button and a status area.

    Args:
        trainer: Object whose ``run`` method is called with the collected
            inputs and whose ``check_if_running`` method feeds the status text.
        pipe: Object whose ``clear`` method is called when training starts.

    Returns:
        The ``gr.Blocks`` instance containing the training UI.
    """

    def ordinal(n: int) -> str:
        # Compact English ordinal formatter: 1 -> "1st", 2 -> "2nd", 3 -> "3rd".
        return "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])

    # Example text shown as textbox placeholders, one entry per concept slot.
    # NOTE(review): the concept examples start with a blank and the help text
    # below renders an empty modifier token; a token in front of each concept
    # appears to be missing from these strings — confirm.
    ordinal_concept = [" cat", " wooden pot", " chair"]
    # The class examples are listed explicitly. They were previously derived
    # with `ordinal_concept[x][7:]`, which yields "", " pot" and "" for the
    # strings above instead of the object category.
    example_classes = ["cat", "wooden pot", "chair"]

    with gr.Blocks() as demo:
        base_model = gr.Dropdown(
            choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
            value='CompVis/stable-diffusion-v1-4',
            label='Base Model',
            visible=True)
        resolution = gr.Dropdown(choices=['512', '768'],
                                 value='512',
                                 label='Resolution',
                                 visible=True)
        with gr.Row():
            with gr.Box():
                concept_images_collection = []
                concept_prompt_collection = []
                class_prompt_collection = []
                buttons_collection = []
                delete_collection = []
                is_visible = []
                maximum_concepts = 3
                row = [None] * maximum_concepts
                for x in range(maximum_concepts):
                    # Only the first concept slot starts out visible.
                    visible = x == 0
                    is_visible.append(gr.State(value=visible))
                    concept_images_collection.append(gr.Files(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', visible=visible))
                    with gr.Column(visible=visible) as row[x]:
                        concept_prompt_collection.append(
                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept prompt ''', max_lines=1,
                                       placeholder=f'''Example: "photo of a {ordinal_concept[x]}"''')
                        )
                        class_prompt_collection.append(
                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} class prompt ''',
                                       max_lines=1, placeholder=f'''Example: "{example_classes[x]}"''')
                        )
                        with gr.Row():
                            # Every slot but the last gets an "Add next" button;
                            # every slot but the first gets a "Delete" button.
                            if x < maximum_concepts - 1:
                                buttons_collection.append(gr.Button(value=f"Add {ordinal(x+2)} concept", visible=visible))
                            if x > 0:
                                delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))

                # "Add" button k reveals slot k, hides itself and, when slot k
                # has an "Add" button of its own, reveals that one too.
                for counter_add, button in enumerate(buttons_collection, start=1):
                    if counter_add < len(buttons_collection):
                        button.click(
                            lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
                            None,
                            [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], buttons_collection[counter_add], is_visible[counter_add], concept_images_collection[counter_add]],
                            queue=False)
                    else:
                        button.click(
                            lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), True],
                            None,
                            [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], is_visible[counter_add]],
                            queue=False)

                # "Delete" button k hides and clears slot k and re-shows the
                # "Add" button that leads to it. When slot k has its own "Add"
                # button, that one is hidden as well. The bound mirrors the add
                # loop instead of hard-coding slot 1, so the wiring stays
                # correct if maximum_concepts changes.
                for counter_delete, delete_button in enumerate(delete_collection, start=1):
                    if counter_delete < len(buttons_collection):
                        delete_button.click(
                            lambda: [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), False],
                            None,
                            [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], buttons_collection[counter_delete], is_visible[counter_delete]],
                            queue=False)
                    else:
                        delete_button.click(
                            lambda: [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), False],
                            None,
                            [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], is_visible[counter_delete]],
                            queue=False)
                gr.Markdown('''
- We use "\" modifier_token in front of the concept, e.g., "\ cat". For multiple concepts use "\", "\" etc. Increase the number of steps with more concepts.
- For a new concept an e.g. concept prompt is "photo of a \ cat" and "cat" for class prompt.
- For a style concept, use "painting in the style of \ art" for concept prompt and "art" for class prompt.
- Class prompt should be the object category.
- If "Train Text Encoder", disable "modifier token" and use any unique text to describe the concept e.g. "ktn cat".
''')
            with gr.Box():
                gr.Markdown('Training Parameters')
                with gr.Row():
                    modifier_token = gr.Checkbox(label='modifier token',
                                                 value=True)
                    train_text_encoder = gr.Checkbox(label='Train Text Encoder',
                                                     value=False)
                num_training_steps = gr.Number(
                    label='Number of Training Steps', value=1000, precision=0)
                learning_rate = gr.Number(label='Learning Rate', value=0.00001)
                batch_size = gr.Number(
                    label='batch_size', value=1, precision=0)
                with gr.Row():
                    use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
                    gradient_checkpointing = gr.Checkbox(label='Enable gradient checkpointing', value=False)
                with gr.Accordion('Other Parameters', open=False):
                    gradient_accumulation = gr.Number(
                        label='Number of Gradient Accumulation',
                        value=1,
                        precision=0)
                    num_reg_images = gr.Number(
                        label='Number of Class Concept images',
                        value=200,
                        precision=0)
                    gen_images = gr.Checkbox(label='Generated images as regularization',
                                             value=False)
                gr.Markdown('''
- It will take about ~10 minutes to train for 1000 steps and ~21GB on a 3090 GPU.
- Our results in the paper are trained with batch-size 4 (8 including class regularization samples).
- Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of slower backward pass.
- Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
- We retrieve real images for class concept using clip_retireval library which can take some time.
''')
        run_button = gr.Button('Start Training')
        with gr.Box():
            with gr.Row():
                check_status_button = gr.Button('Check Training Status')
                with gr.Column():
                    with gr.Box():
                        gr.Markdown('Message')
                        training_status = gr.Markdown()
                    output_files = gr.Files(label='Trained Weight Files')

        # Starting a run first calls pipe.clear, then hands every collected
        # input to the trainer.
        run_button.click(fn=pipe.clear,
                         inputs=None,
                         outputs=None,)
        run_button.click(fn=trainer.run,
                         inputs=[
                             base_model,
                             resolution,
                             num_training_steps,
                             learning_rate,
                             train_text_encoder,
                             modifier_token,
                             gradient_accumulation,
                             batch_size,
                             use_8bit_adam,
                             gradient_checkpointing,
                             gen_images,
                             num_reg_images,
                         ] +
                         concept_images_collection +
                         concept_prompt_collection +
                         class_prompt_collection,
                         outputs=[
                             training_status,
                             output_files,
                         ],
                         queue=False)
        check_status_button.click(fn=trainer.check_if_running,
                                  inputs=None,
                                  outputs=training_status,
                                  queue=False)
        check_status_button.click(fn=update_output_files,
                                  inputs=None,
                                  outputs=output_files,
                                  queue=False)
    return demo
def find_weight_files() -> list[str]:
    """List every ``*.bin`` file found below this file's directory.

    The search is recursive. Paths whose string form contains ``.lfs`` are
    skipped, and the remaining ones are returned in sorted order as POSIX
    paths relative to this file's directory.
    """
    app_root = pathlib.Path(__file__).parent
    return [
        candidate.relative_to(app_root).as_posix()
        for candidate in sorted(app_root.rglob('*.bin'))
        if '.lfs' not in str(candidate)
    ]
def reload_custom_diffusion_weight_list() -> dict:
    """Return a Gradio update whose choices are the current weight files."""
    current_weights = find_weight_files()
    return gr.update(choices=current_weights)
def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
    """Build the inference UI and wire it to ``pipe.run``.

    The left column collects the base model, resolution, weight file, prompt,
    seed and sampling parameters; the right column shows the resulting image.
    Submitting the prompt or pressing "Generate" both call ``pipe.run`` with
    the same inputs.

    Args:
        pipe: Object whose ``run`` method produces the value shown in the
            result image component.

    Returns:
        The ``gr.Blocks`` instance containing the inference UI.
    """
    # Preselect the bundled example weight only when it was actually found.
    # Otherwise fall back to the first discovered file, or to no selection,
    # so the dropdown never starts on a value that is not among its choices.
    available_weights = find_weight_files()
    preferred_weight = 'custom-diffusion-models/cat.bin'
    if preferred_weight in available_weights:
        default_weight = preferred_weight
    elif available_weights:
        default_weight = available_weights[0]
    else:
        default_weight = None

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                base_model = gr.Dropdown(
                    choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
                    value='CompVis/stable-diffusion-v1-4',
                    label='Base Model',
                    visible=True)
                resolution = gr.Dropdown(choices=[512, 768],
                                         value=512,
                                         label='Resolution',
                                         visible=True)
                reload_button = gr.Button('Reload Weight List')
                weight_name = gr.Dropdown(choices=available_weights,
                                          value=default_weight,
                                          label='Custom Diffusion Weight File')
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "\ cat in outer space"')
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=42)
                with gr.Accordion('Other Parameters', open=False):
                    num_steps = gr.Slider(label='Number of Steps',
                                          minimum=0,
                                          maximum=500,
                                          step=1,
                                          value=100)
                    guidance_scale = gr.Slider(label='CFG Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=6)
                    eta = gr.Slider(label='DDIM eta',
                                    minimum=0,
                                    maximum=1.,
                                    step=0.1,
                                    value=1.)
                    batch_size = gr.Slider(label='Batch Size',
                                           minimum=0,
                                           maximum=10.,
                                           step=1,
                                           value=1)
                run_button = gr.Button('Generate')
                gr.Markdown('''
- Models with names starting with "custom-diffusion-models/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones with names starting with "results/delta.bin" are your trained models.
- After training, you can press "Reload Weight List" button to load your trained model names.
- Increase number of steps in Other parameters for better samples qualitatively.
''')
            with gr.Column():
                result = gr.Image(label='Result')

        reload_button.click(fn=reload_custom_diffusion_weight_list,
                            inputs=None,
                            outputs=weight_name)

        # Both triggers use the same argument list, so it is defined once to
        # keep the two calls from drifting apart.
        generation_inputs = [
            base_model,
            weight_name,
            prompt,
            seed,
            num_steps,
            guidance_scale,
            eta,
            batch_size,
            resolution,
        ]
        prompt.submit(fn=pipe.run,
                      inputs=generation_inputs,
                      outputs=result,
                      queue=False)
        run_button.click(fn=pipe.run,
                         inputs=generation_inputs,
                         outputs=result,
                         queue=False)
    return demo
def create_upload_demo() -> gr.Blocks:
    """Build the UI for uploading a trained model to the Hugging Face Hub.

    The form collects a model name and a Hugging Face token and passes both to
    ``upload`` when the "Upload" button is pressed; the returned value is
    rendered in the message area.

    Returns:
        The ``gr.Blocks`` instance containing the upload UI.
    """
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label='Model Name')
        # The write token is a credential: mask it instead of echoing it in
        # clear text on screen.
        hf_token = gr.Textbox(
            label='Hugging Face Token (with write permission)',
            type='password')
        upload_button = gr.Button('Upload')
        with gr.Box():
            gr.Markdown('Message')
            result = gr.Markdown()
        gr.Markdown('''
- You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
- You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
''')
        # NOTE(review): `upload` is not defined in the visible part of this
        # file (the `from uploader import upload` line is commented out), so
        # this lookup fails unless the name is provided elsewhere — confirm.
        upload_button.click(fn=upload,
                            inputs=[model_name, hf_token],
                            outputs=result)
    return demo
# Backends shared by the tabs below.
pipe = InferencePipeline()
# NOTE(review): `Trainer` is not defined in the visible part of this file (the
# `from trainer import Trainer` line is commented out), so this call fails
# unless the name is provided elsewhere — confirm.
trainer = Trainer()

with gr.Blocks(css='style.css') as demo:
    # Work out which start-up banners apply, then render them in order.
    startup_warnings = []
    if os.getenv('IS_SHARED_UI'):
        startup_warnings.append(SHARED_UI_WARNING)
    if not torch.cuda.is_available():
        startup_warnings.append(CUDA_NOT_AVAILABLE_WARNING)
    for warning_text in startup_warnings:
        show_warning(warning_text)

    # Page header: title, short description, detailed description.
    for header_text in (TITLE, DESCRIPTION, DETAILDESCRIPTION):
        gr.Markdown(header_text)

    with gr.Tabs():
        with gr.TabItem('Train'):
            create_training_demo(trainer, pipe)
        with gr.TabItem('Test'):
            create_inference_demo(pipe)
        with gr.TabItem('Upload'):
            create_upload_demo()

# Queueing stays off by default; individual events opt in where needed.
queued_demo = demo.queue(default_enabled=False)
queued_demo.launch(share=False)