#!/usr/bin/env python
"""Demo app for https://github.com/adobe-research/custom-diffusion.
The code in this repo is partly adapted from the following repository:
https://huggingface.co/spaces/hysts/LoRA-SD-training
MIT License
Copyright (c) 2022 hysts
==========================================================================================
Adobe’s modifications are Copyright 2022 Adobe Research. All rights reserved.
Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, see the
LICENSE file.
==========================================================================================
"""
from __future__ import annotations
import sys
import os
import pathlib
import gradio as gr
import torch
from inference import inference_fn
# from inference_custom_diffusion import InferencePipeline
# from trainer import Trainer
# from uploader import upload
TITLE = '# Custom Diffusion + StableDiffusion Training UI'
DESCRIPTION = '''This is a demo for [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion).
It is recommended to upgrade to GPU in Settings after duplicating this space to use it.
'''
DETAILDESCRIPTION = '''
Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
We fine-tune only a subset of model parameters, namely key and value projection matrices, in the cross-attention layers and the modifier token used to represent the object.
This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There's still limitations on which compositions work. For more analysis please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
'''
ORIGINAL_SPACE_ID = 'Ziqi/ReVersion'
# Id of the Space this process runs in (env var SPACE_ID); falls back to the
# original Space id when unset, e.g. when running locally.
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
SHARED_UI_WARNING = '''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
'''
# On a duplicated Space (running on HF Spaces with a SPACE_ID other than the
# original), point the user straight at that Space's settings page so they can
# assign a GPU. Previously both branches produced the same plain word, which
# made this conditional dead code.
if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
    SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
else:
    SETTINGS = 'Settings'
CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
'''
# Fetch the ReVersion sources into ./ReVersion (relative to the current working
# directory) and make them importable. Skip the clone when the checkout already
# exists: re-running `git clone` into an existing directory only fails with an
# error message on every restart. Best-effort, as before: a failed clone is not
# raised here.
if not os.path.isdir('ReVersion'):
    os.system("git clone https://github.com/ziqihuangg/ReVersion")
sys.path.append("ReVersion")
def show_warning(warning_text: str) -> gr.Blocks:
    """Render *warning_text* as Markdown inside a boxed ``gr.Blocks``.

    Args:
        warning_text: Markdown source of the warning banner.

    Returns:
        The ``gr.Blocks`` holding the boxed warning.
    """
    # NOTE(review): gr.Box exists in Gradio 3.x; confirm the pinned Gradio
    # version still provides it before upgrading.
    with gr.Blocks() as warning_block, gr.Box():
        gr.Markdown(warning_text)
    return warning_block
def update_output_files() -> dict:
    """Build a Gradio update listing the ``.bin`` files directly inside ``results/``.

    Returns:
        A ``gr.update`` whose ``value`` is the sorted list of POSIX-style paths,
        or ``None`` when no ``.bin`` file is found.
    """
    bin_files = sorted(pathlib.Path('results').glob('*.bin'))
    if not bin_files:
        return gr.update(value=None)
    return gr.update(value=[bin_file.as_posix() for bin_file in bin_files])
def find_weight_files() -> list[str]:
    """List every ``.bin`` file under this script's directory, recursively.

    Any path whose full string form contains ``.lfs`` is skipped. Results are
    ordered by ``Path`` sort order and returned as POSIX-style paths relative
    to the script's directory.
    """
    base_dir = pathlib.Path(__file__).parent
    return [
        candidate.relative_to(base_dir).as_posix()
        for candidate in sorted(base_dir.rglob('*.bin'))
        if '.lfs' not in str(candidate)
    ]
def reload_custom_diffusion_weight_list() -> dict:
    """Return a Gradio update that replaces a component's choices with the current weight files."""
    weight_files = find_weight_files()
    return gr.update(choices=weight_files)
def create_inference_demo(func: inference_fn) -> gr.Blocks:
    """Build the inference UI and wire it to *func*.

    Pressing "Generate", or submitting the prompt textbox, calls *func* with
    ``(model_id, prompt, placeholder_string, guidance_scale)`` taken from the
    UI. The return value is displayed in the "Result" image component.
    Neither event uses the queue (``queue=False``).

    Args:
        func: Inference callback; the module-level caller passes
            ``inference.inference_fn``.

    Returns:
        The ``gr.Blocks`` containing the inference UI.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                model_id = gr.Dropdown(
                    choices=['experiments/painted_on'],
                    value='experiments/painted_on',
                    label='Relation',
                    visible=True)
                # NOTE(review): no handler is attached to this button (the
                # reload wiring below was commented out because there is no
                # weight dropdown to update), so clicking it does nothing.
                reload_button = gr.Button('Reload Weight List')
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "cat stone"')
                placeholder_string = gr.Textbox(
                    label='Placeholder String',
                    max_lines=1,
                    placeholder='Example: ""')
                with gr.Accordion('Other Parameters', open=False):
                    guidance_scale = gr.Slider(label='Classifier-Free Guidance Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=7.5)
                    # A batch of 0 images is meaningless, so the range starts at 1.
                    # NOTE(review): this value is not passed to `func`, so the
                    # slider currently has no effect. Confirm inference_fn's
                    # signature before wiring it in.
                    num_samples = gr.Slider(label='Batch Size',
                                            minimum=1,
                                            maximum=10,
                                            step=1,
                                            value=10)
                run_button = gr.Button('Generate')
                # NOTE(review): this help text mentions custom-diffusion weights
                # and a step count that this UI does not expose; may be stale.
                gr.Markdown('''
- Models with names starting with "custom-diffusion-models/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones with names starting with "results/delta.bin" are your trained models.
- After training, you can press "Reload Weight List" button to load your trained model names.
- Increase number of steps in Other parameters for better samples qualitatively.
''')
            with gr.Column():
                result = gr.Image(label='Result')
        # reload_button.click(fn=reload_custom_diffusion_weight_list,
        #                     inputs=None,
        #                     outputs=weight_name)

        # Both triggers call `func` with the same inputs, in the order
        # `func` expects them; build the list once so they cannot drift apart.
        inference_inputs = [
            model_id,
            prompt,
            placeholder_string,
            guidance_scale,
        ]
        prompt.submit(fn=func,
                      inputs=inference_inputs,
                      outputs=result,
                      queue=False)
        run_button.click(fn=func,
                         inputs=inference_inputs,
                         outputs=result,
                         queue=False)
    return demo
with gr.Blocks(css='style.css') as demo:
    # Environment banners: the shared-UI notice when IS_SHARED_UI is set,
    # then the CPU notice when no CUDA device is visible to torch.
    if os.getenv('IS_SHARED_UI'):
        show_warning(SHARED_UI_WARNING)
    if not torch.cuda.is_available():
        show_warning(CUDA_NOT_AVAILABLE_WARNING)

    # Page header, rendered in this fixed order.
    for _header_text in (TITLE, DESCRIPTION, DETAILDESCRIPTION):
        gr.Markdown(_header_text)

    with gr.Tabs(), gr.TabItem('Test'):
        create_inference_demo(inference_fn)

demo.queue(default_enabled=False).launch(share=False)