|
import re |
|
import os |
|
import yaml |
|
import tempfile |
|
import subprocess |
|
from pathlib import Path |
|
|
|
import torch |
|
import gradio as gr |
|
|
|
from src.flux.xflux_pipeline import XFluxPipeline |
|
|
|
|
|
def list_dirs(path):
    """Yield directory choices around *path* for a directory picker.

    The yielded sequence is: the parent of *path* (when it has a distinct,
    non-empty parent), *path* itself, then every visible sub-directory of
    *path* in natural ("a2" before "a10"), case-insensitive order. Hidden
    entries (leading ".") and ``__pycache__`` are skipped. On Windows the
    separators are normalised to "/".

    If *path* does not exist, its parent is used instead (nothing is yielded
    if that does not exist either). If *path* is a file, the directory that
    contains it is used. ``None``, ``"None"`` and ``""`` yield nothing.
    """
    if path is None or path == "None" or path == "":
        return

    if not os.path.exists(path):
        # Fall back to the parent of a non-existent path (e.g. a partially typed entry).
        path = os.path.dirname(path)
        if not os.path.exists(path):
            return

    if not os.path.isdir(path):
        # A file was given: list the directory containing it. A bare relative
        # file name has an empty dirname, which means the current directory.
        path = os.path.dirname(path) or "."

    def natural_sort_key(s, regex=re.compile(r"([0-9]+)")):
        # re.split with a capturing group puts the captured digit runs at the
        # odd indices, so only those are converted to int. The key therefore
        # always alternates str/int and stays comparable between entries.
        return [
            int(part) if index % 2 else part.lower()
            for index, part in enumerate(regex.split(s))
        ]

    subdirs = sorted(
        (
            os.path.join(path, item)
            for item in os.listdir(path)
            if os.path.isdir(os.path.join(path, item))
            and not item.startswith(".")
            and item != "__pycache__"
        ),
        key=natural_sort_key,
    )

    parent = os.path.dirname(path)
    # A filesystem root ("/", "C:\\") is its own dirname; don't list it twice.
    if parent and parent != path:
        dirs = [parent, path] + subdirs
    else:
        dirs = [path] + subdirs

    if os.sep == "\\":
        dirs = [d.replace("\\", "/") for d in dirs]

    yield from dirs
|
|
|
def list_train_data_dirs():
    """Return the directory choices for the LoRA training-data dropdown.

    The listing is rooted at the current working directory (``"."``) and is
    produced by ``list_dirs``; the generator is materialised into a list.
    """
    return [directory for directory in list_dirs(".")]
|
|
|
def update_config(d, u):
    """Recursively merge override values from *u* into config dict *d*.

    Nested dicts in *u* are merged into the matching sub-dict of *d*. A fresh
    dict is used when *d* has no dict under that key. Leaf values are coerced
    as follows:

    * objects exposing ``.value`` (e.g. enum members) -> ``str(obj.value)``;
    * non-integral floats (``1e-5``, ``0.5``, ``inf``, ``nan``) are kept as
      floats instead of being truncated by ``int()``;
    * anything ``int()`` accepts (ints/bools, integral floats, digit strings)
      -> ``int``;
    * everything else (including ``None``) -> ``str(value)``.

    Args:
        d: Config dict to update; it is mutated in place.
        u: Overrides to apply.

    Returns:
        The updated *d*.
    """
    for key, value in u.items():
        if isinstance(value, dict):
            existing = d.get(key)
            # The base config may hold None/a scalar here; merge into a new dict
            # instead of crashing on item assignment.
            if not isinstance(existing, dict):
                existing = {}
            d[key] = update_config(existing, value)
        elif hasattr(value, "value"):
            d[key] = str(value.value)
        elif isinstance(value, float) and not value.is_integer():
            # int() would silently truncate (1e-5 -> 0) or raise OverflowError (inf).
            d[key] = value
        else:
            try:
                d[key] = int(value)
            except (TypeError, ValueError):
                d[key] = str(value)
    return d
|
|
|
def start_lora_training(
    data_dir: str, output_dir: str, lr: float, steps: int, rank: int
):
    """Launch LoRA fine-tuning via ``accelerate`` with an updated copy of the base config.

    Loads ``train_configs/test_lora.yaml`` located next to this file and
    overrides the image directory, output directory, learning rate, LoRA rank
    and step count. The result is written to a temporary YAML file, and
    ``train_flux_lora_deepspeed.py`` (also next to this file) is run on it.
    The temporary file is removed afterwards, even if training fails.

    Args:
        data_dir: Directory containing the training images.
        output_dir: Directory for the output checkpoint; created if missing.
        lr: Learning rate. NOTE(review): the Gradio UI wires this to a
            Textbox, so it may arrive as a string such as "1e-5".
        steps: Maximum number of training steps.
        rank: LoRA rank.

    Returns:
        The ``subprocess.CompletedProcess`` of the training run.

    Raises:
        subprocess.CalledProcessError: If the training process exits non-zero.
    """
    inputs = {
        "data_config": {
            "img_dir": data_dir,
        },
        "output_dir": output_dir,
        "learning_rate": lr,
        "rank": rank,
        "max_train_steps": steps,
    }

    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Creating folder {output_dir} for the output checkpoint file...")

    script_dir = Path(__file__).resolve().parent
    config_path = script_dir / "train_configs" / "test_lora.yaml"
    with open(config_path, 'r', encoding="utf-8") as file:
        config = yaml.safe_load(file)

    config = update_config(config, inputs)
    print("Config file is updated...", config)

    # delete=False: the file must outlive this handle so the child process can read it.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".yaml") as temp_file:
        yaml.dump(config, temp_file, default_flow_style=False)
        tmp_config_path = temp_file.name

    # Resolve the training script next to this file, the same way the base
    # config is resolved, so launching from another working directory works.
    train_script = script_dir / "train_flux_lora_deepspeed.py"
    command = ["accelerate", "launch", str(train_script), "--config", tmp_config_path]
    try:
        result = subprocess.run(command, check=True)
    finally:
        # check=True raises on failure; still remove the temporary config.
        Path(tmp_config_path).unlink(missing_ok=True)

    return result
|
|
|
|
|
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    ckpt_dir: str = "",
):
    """Build the Gradio UI for XFlux inference and LoRA fine-tuning.

    Args:
        model_type: Model name passed to ``XFluxPipeline`` and shown in the header.
        device: Device passed to ``XFluxPipeline``; defaults to CUDA when available.
        offload: Passed through to ``XFluxPipeline``.
        ckpt_dir: Directory scanned (non-recursively) for ``*.safetensors`` files.
            These populate the ControlNet / LoRA / IP-Adapter checkpoint
            dropdowns. ``""`` means the current directory.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    xflux_pipeline = XFluxPipeline(model_type, device, offload)
    # Use plain str choices. Recent Gradio versions validate the submitted
    # dropdown value against the choice values, and the browser always sends
    # back a string, so pathlib.Path choices would never match. The selected
    # value is also forwarded to the pipeline as a checkpoint path.
    checkpoints = [str(p) for p in sorted(Path(ckpt_dir).glob("*.safetensors"))]

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")
        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")

                    with gr.Accordion("Generation Options", open=False):
                        with gr.Row():
                            width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                            height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                        neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                        with gr.Row():
                            num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                            timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
                        with gr.Row():
                            guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                            true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
                        seed = gr.Textbox(-1, label="Seed (-1 for random)")

                    with gr.Accordion("ControlNet Options", open=False):
                        control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
                        control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
                        local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint",
                            info="Local Path to Controlnet weights (if no, it will be downloaded from HF)"
                        )
                        controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)

                    with gr.Accordion("LoRA Options", open=False):
                        lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
                        lora_local_path = gr.Dropdown(
                            checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights"
                        )

                    with gr.Accordion("IP Adapter Options", open=False):
                        image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                        ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                        neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
                        neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                        ip_local_path = gr.Dropdown(
                            checkpoints, label="IP Adapter Checkpoint",
                            info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)"
                        )
                    generate_btn = gr.Button("Generate")

                with gr.Column():
                    output_image = gr.Image(label="Generated Image")
                    download_btn = gr.File(label="Download full-resolution")

            # Positional order must match the parameters of
            # XFluxPipeline.gradio_generate (not visible here — keep in sync).
            inputs = [prompt, image_prompt, controlnet_image, width, height, guidance,
                      num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                      neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                      lora_weight, local_path, lora_local_path, ip_local_path
                      ]
            generate_btn.click(
                fn=xflux_pipeline.gradio_generate,
                inputs=inputs,
                outputs=[output_image, download_btn],
            )

        with gr.Tab("LoRA Finetuning"):
            data_dir = gr.Dropdown(list_train_data_dirs(),
                label="Training images (directory containing the training images)"
            )
            output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint")

            with gr.Accordion("Training Options", open=True):
                lr = gr.Textbox(label="Learning Rate", value="1e-5")
                steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps")
                rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank")

            training_btn = gr.Button("Start training")
            training_btn.click(
                fn=start_lora_training,
                inputs=[data_dir, output_dir, lr, steps, rank],
                outputs=[],
            )

    return demo
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, build the Gradio app, serve it.
    import argparse

    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    arg_parser = argparse.ArgumentParser(description="Flux")
    arg_parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
    arg_parser.add_argument("--device", type=str, default=default_device, help="Device to use")
    arg_parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    arg_parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    arg_parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
    cli_args = arg_parser.parse_args()

    app = create_demo(
        cli_args.name,
        device=cli_args.device,
        offload=cli_args.offload,
        ckpt_dir=cli_args.ckpt_dir,
    )
    app.launch(share=cli_args.share)