Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
"""Demo app for https://github.com/adobe-research/custom-diffusion. | |
The code in this repo is partly adapted from the following repository: | |
https://huggingface.co/spaces/hysts/LoRA-SD-training | |
MIT License | |
Copyright (c) 2022 hysts | |
========================================================================================== | |
Adobe’s modifications are Copyright 2022 Adobe Research. All rights reserved. | |
Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit | |
LICENSE. | |
========================================================================================== | |
""" | |
from __future__ import annotations | |
import sys | |
import os | |
import pathlib | |
import gradio as gr | |
import torch | |
from inference import inference_fn | |
# from inference_custom_diffusion import InferencePipeline | |
# from trainer import Trainer | |
# from uploader import upload | |
# Markdown heading rendered at the top of the page.
TITLE = '# Custom Diffusion + StableDiffusion Training UI'

# Short intro rendered under the title, including the "Duplicate Space" badge.
DESCRIPTION = '''This is a demo for [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion).
It is recommended to upgrade to GPU in Settings after duplicating this space to use it.
<a href="https://huggingface.co/spaces/nupurkmr9/custom-diffusion?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
'''

# Longer method description (with figure) rendered after DESCRIPTION.
DETAILDESCRIPTION = '''
Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
We fine-tune only a subset of model parameters, namely key and value projection matrices, in the cross-attention layers and the modifier token used to represent the object.
This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There's still limitations on which compositions work. For more analysis please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
<center>
<img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
</center>
'''
# SPACE_ID is read from the environment and falls back to the original space id.
ORIGINAL_SPACE_ID = 'Ziqi/ReVersion'
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)

# Banner shown when the IS_SHARED_UI environment variable is set.
SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
'''

# Link to the settings page only when running on Spaces under a different
# space id than the original; otherwise fall back to the plain word.
SETTINGS = (
    f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
    if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID
    else 'Settings'
)

# Banner shown when torch reports that CUDA is unavailable.
CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
<center>
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
</center>
'''
import subprocess

# Fetch the ReVersion sources into the working directory so they can be
# imported. The clone is skipped when a checkout is already present (a second
# `git clone` into an existing directory only fails), and the command is passed
# as an argument list so no shell is involved.
if not os.path.isdir('ReVersion'):
    try:
        clone = subprocess.run(
            ['git', 'clone', 'https://github.com/ziqihuangg/ReVersion'],
            check=False,
        )
    except OSError as err:
        # Best effort, like the original os.system call: a missing git binary
        # must not prevent the app from starting.
        print(f'Could not run git to clone ReVersion: {err}', file=sys.stderr)
    else:
        if clone.returncode != 0:
            print(f'git clone of ReVersion exited with status {clone.returncode}',
                  file=sys.stderr)

# Make modules inside the checkout importable.
sys.path.append("ReVersion")
def show_warning(warning_text: str) -> gr.Blocks:
    """Render ``warning_text`` as Markdown inside a boxed ``gr.Blocks``.

    Args:
        warning_text: Markdown source of the warning to display.

    Returns:
        The ``gr.Blocks`` container wrapping the warning box.
    """
    demo = gr.Blocks()
    with demo, gr.Box():
        gr.Markdown(warning_text)
    return demo
def update_output_files() -> dict:
    """Build a Gradio update listing the ``*.bin`` files found in ``results``.

    Returns:
        ``gr.update`` whose value is the sorted list of matching paths (as
        POSIX strings), or ``None`` when there are no matches.
    """
    results_dir = pathlib.Path('results')
    bin_files = [found.as_posix() for found in sorted(results_dir.glob('*.bin'))]
    if not bin_files:
        return gr.update(value=None)
    return gr.update(value=bin_files)
def find_weight_files() -> list[str]:
    """Collect every ``*.bin`` file below the directory containing this file.

    Files whose full path contains ``.lfs`` are skipped.

    Returns:
        Paths relative to this file's directory, as POSIX-style strings, in
        the sorted order of the underlying ``Path`` objects.
    """
    base_dir = pathlib.Path(__file__).parent
    weight_files = []
    for candidate in sorted(base_dir.rglob('*.bin')):
        if '.lfs' in str(candidate):
            continue
        weight_files.append(candidate.relative_to(base_dir).as_posix())
    return weight_files
def reload_custom_diffusion_weight_list() -> dict:
    """Return a Gradio update whose choices are the current weight files."""
    available_weights = find_weight_files()
    return gr.update(choices=available_weights)
def create_inference_demo(func: inference_fn) -> gr.Blocks:
    """Build the inference UI: input controls in one column, result image in the other.

    Args:
        func: Callable run when the prompt is submitted or "Generate" is
            clicked. It receives the values of the relation dropdown, the
            prompt, the placeholder string and the guidance scale, in that
            order; its output is sent to the result image.

    Returns:
        The ``gr.Blocks`` container holding the inference UI.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                model_id = gr.Dropdown(choices=['experiments/painted_on'],
                                       value='experiments/painted_on',
                                       label='Relation',
                                       visible=True)
                # NOTE: this button is rendered but no click handler is
                # attached to it in this function.
                reload_button = gr.Button('Reload Weight List')
                prompt = gr.Textbox(label='Prompt',
                                    max_lines=1,
                                    placeholder='Example: "cat <R> stone"')
                placeholder_string = gr.Textbox(label='Placeholder String',
                                                max_lines=1,
                                                placeholder='Example: "<R>"')
                with gr.Accordion('Other Parameters', open=False):
                    guidance_scale = gr.Slider(
                        label='Classifier-Free Guidance Scale',
                        minimum=0,
                        maximum=50,
                        step=0.1,
                        value=7.5)
                    # NOTE: shown in the UI but not passed to ``func`` below.
                    num_samples = gr.Slider(label='Batch Size',
                                            minimum=0,
                                            maximum=10.,
                                            step=1,
                                            value=10)

                run_button = gr.Button('Generate')

                gr.Markdown('''
- Models with names starting with "custom-diffusion-models/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones with names starting with "results/delta.bin" are your trained models.
- After training, you can press "Reload Weight List" button to load your trained model names.
- Increase number of steps in Other parameters for better samples qualitatively.
''')
            with gr.Column():
                result = gr.Image(label='Result')

        # Submitting the prompt and pressing "Generate" trigger the same call.
        generate_event = dict(
            fn=func,
            inputs=[model_id, prompt, placeholder_string, guidance_scale],
            outputs=result,
            queue=False,
        )
        prompt.submit(**generate_event)
        run_button.click(**generate_event)
    return demo
# Assemble the page: optional warning banners, the header texts, then the
# single "Test" tab holding the inference UI.
with gr.Blocks(css='style.css') as demo:
    if os.getenv('IS_SHARED_UI'):
        show_warning(SHARED_UI_WARNING)
    if not torch.cuda.is_available():
        show_warning(CUDA_NOT_AVAILABLE_WARNING)

    for header_markdown in (TITLE, DESCRIPTION, DETAILDESCRIPTION):
        gr.Markdown(header_markdown)

    with gr.Tabs():
        with gr.TabItem('Test'):
            create_inference_demo(inference_fn)

# Queueing is off by default; the app is served without a public share link.
app = demo.queue(default_enabled=False)
app.launch(share=False)