DigiP-AI committed
Commit ed3b341 · verified · 1 Parent(s): 724ca21

Create app.py

Files changed (1)
  1. app.py +209 -0
app.py ADDED
@@ -0,0 +1,209 @@
+import re
+import os
+import yaml
+import tempfile
+import subprocess
+from pathlib import Path
+
+import torch
+import gradio as gr
+
+from src.flux.xflux_pipeline import XFluxPipeline
+
+
+def list_dirs(path):
+    if path is None or path == "None" or path == "":
+        return
+
+    if not os.path.exists(path):
+        path = os.path.dirname(path)
+        if not os.path.exists(path):
+            return
+
+    if not os.path.isdir(path):
+        path = os.path.dirname(path)
+
+    def natural_sort_key(s, regex=re.compile("([0-9]+)")):
+        return [
+            int(text) if text.isdigit() else text.lower() for text in regex.split(s)
+        ]
+
+    subdirs = [
+        (item, os.path.join(path, item))
+        for item in os.listdir(path)
+        if os.path.isdir(os.path.join(path, item))
+    ]
+    subdirs = [
+        filename
+        for item, filename in subdirs
+        if item[0] != "." and item not in ["__pycache__"]
+    ]
+    subdirs = sorted(subdirs, key=natural_sort_key)
+    if os.path.dirname(path) != "":
+        dirs = [os.path.dirname(path), path] + subdirs
+    else:
+        dirs = [path] + subdirs
+
+    if os.sep == "\\":
+        dirs = [d.replace("\\", "/") for d in dirs]
+    for d in dirs:
+        yield d
+
+def list_train_data_dirs():
+    current_train_data_dir = "."
+    return list(list_dirs(current_train_data_dir))
+
+def update_config(d, u):
+    for k, v in u.items():
+        if isinstance(v, dict):
+            d[k] = update_config(d.get(k, {}), v)
+        else:
+            # convert Gradio components to strings
+            if hasattr(v, 'value'):
+                d[k] = str(v.value)
+            else:
+                try:
+                    d[k] = int(v)
+                except (TypeError, ValueError):
+                    d[k] = str(v)
+    return d
+
+def start_lora_training(
+    data_dir: str, output_dir: str, lr: float, steps: int, rank: int
+):
+    inputs = {
+        "data_config": {
+            "img_dir": data_dir,
+        },
+        "output_dir": output_dir,
+        "learning_rate": lr,
+        "rank": rank,
+        "max_train_steps": steps,
+    }
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        print(f"Creating folder {output_dir} for the output checkpoint file...")
+
+    script_path = Path(__file__).resolve()
+    config_path = script_path.parent / "train_configs" / "test_lora.yaml"
+    with open(config_path, 'r') as file:
+        config = yaml.safe_load(file)
+
+    config = update_config(config, inputs)
+    print("Config file is updated...", config)
+    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".yaml") as temp_file:
+        yaml.dump(config, temp_file, default_flow_style=False)
+        tmp_config_path = temp_file.name
+
+    command = ["accelerate", "launch", "train_flux_lora_deepspeed.py", "--config", tmp_config_path]
+    result = subprocess.run(command, check=True)
+
+    # Remove the temporary file after the command is run
+    Path(tmp_config_path).unlink()
+
+    return result
+
+
+def create_demo(
+    model_type: str,
+    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    offload: bool = False,
+    ckpt_dir: str = "",
+):
+    xflux_pipeline = XFluxPipeline(model_type, device, offload)
+    checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
+
+    with gr.Blocks() as demo:
+        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")
+        with gr.Tab("Inference"):
+            with gr.Row():
+                with gr.Column():
+                    prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
+
+                    with gr.Accordion("Generation Options", open=False):
+                        with gr.Row():
+                            width = gr.Slider(512, 2048, 1024, step=16, label="Width")
+                            height = gr.Slider(512, 2048, 1024, step=16, label="Height")
+                            neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
+                        with gr.Row():
+                            num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
+                            timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
+                        with gr.Row():
+                            guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
+                            true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
+                        seed = gr.Textbox(-1, label="Seed (-1 for random)")
+
+                    with gr.Accordion("ControlNet Options", open=False):
+                        control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
+                        control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
+                        local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint",
+                            info="Local Path to Controlnet weights (if no, it will be downloaded from HF)"
+                        )
+                        controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)
+
+                    with gr.Accordion("LoRA Options", open=False):
+                        lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
+                        lora_local_path = gr.Dropdown(
+                            checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights"
+                        )
+
+                    with gr.Accordion("IP Adapter Options", open=False):
+                        image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
+                        ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
+                        neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
+                        neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
+                        ip_local_path = gr.Dropdown(
+                            checkpoints, label="IP Adapter Checkpoint",
+                            info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)"
+                        )
+                    generate_btn = gr.Button("Generate")
+
+                with gr.Column():
+                    output_image = gr.Image(label="Generated Image")
+                    download_btn = gr.File(label="Download full-resolution")
+
+            inputs = [prompt, image_prompt, controlnet_image, width, height, guidance,
+                      num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
+                      neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
+                      lora_weight, local_path, lora_local_path, ip_local_path
+                      ]
+            generate_btn.click(
+                fn=xflux_pipeline.gradio_generate,
+                inputs=inputs,
+                outputs=[output_image, download_btn],
+            )
+
+        with gr.Tab("LoRA Finetuning"):
+            data_dir = gr.Dropdown(list_train_data_dirs(),
+                label="Training images (directory containing the training images)"
+            )
+            output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint")
+
+            with gr.Accordion("Training Options", open=True):
+                lr = gr.Textbox(label="Learning Rate", value="1e-5")
+                steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps")
+                rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank")
+
+            training_btn = gr.Button("Start training")
+            training_btn.click(
+                fn=start_lora_training,
+                inputs=[data_dir, output_dir, lr, steps, rank],
+                outputs=[],
+            )
+
+
+    return demo
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="Flux")
+    parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
+    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
+    parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
+    args = parser.parse_args()
+
+    demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir)
+    demo.launch(share=args.share)
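
Usage note: a minimal sketch of launching this demo programmatically, mirroring the __main__ block above. The model name and checkpoint directory are placeholders, and the sketch assumes the file is saved as app.py at the repository root so that src.flux.xflux_pipeline resolves; this is illustrative, not part of the commit.

    # Hypothetical launcher; model name and paths are placeholders.
    import torch
    from app import create_demo

    demo = create_demo(
        model_type="flux-dev",                                  # placeholder model name (default of --name)
        device="cuda" if torch.cuda.is_available() else "cpu",  # pick GPU when available
        offload=False,                                          # True offloads the model to CPU when not in use
        ckpt_dir="./checkpoints",                               # placeholder folder of *.safetensors checkpoints
    )
    demo.launch(share=False)                                    # share=True creates a public Gradio link

The CLI entry point at the bottom of the file gives the equivalent launch, e.g. python app.py --name flux-dev --ckpt_dir ./checkpoints.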