yiren98 committed on
Commit
d1bd23e
·
1 Parent(s): 373b0c2
Files changed (1) hide show
  1. gradio_app.py +111 -87
gradio_app.py CHANGED
@@ -12,61 +12,84 @@ from safetensors.torch import load_file
12
  from networks import lora_flux
13
  from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
14
  import logging
 
 
 
15
 
16
  # Set up logger
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.DEBUG)
19
 
20
- # Ensure necessary devices are available
21
- print("torch.__version__: ", torch.__version__)
22
- device = "cuda" if torch.cuda.is_available() else "cpu"
23
- logger.info("device: {}".format(device))
24
-
25
  accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
26
 
27
- # Model paths (replace these with your actual model paths)
28
- BASE_FLUX_CHECKPOINT="/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/6_Portrait/6_Portrait.safetensors"
29
- LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/RecraftModel/6_Portrait/6_Portrait-step00025000.safetensors"
30
- CLIP_L_PATH="/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
31
- T5XXL_PATH="/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/t5xxl_fp16.safetensors"
32
- AE_PATH="/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
33
-
34
- from huggingface_hub import login
35
  hf_token = os.getenv("HF_TOKEN")
36
  login(token=hf_token)
37
 
38
- from huggingface_hub import hf_hub_download
39
- repo_id = "Kijai/flux-fp8"
40
- file_name = "flux1-dev-fp8.safetensors"
41
- BASE_FLUX_CHECKPOINT = hf_hub_download(repo_id=repo_id, filename=file_name)
42
- print(f"Downloaded BASE_FLUX_CHECKPOINT saved at: {BASE_FLUX_CHECKPOINT}")
43
-
44
- from huggingface_hub import hf_hub_download
45
- repo_id = "comfyanonymous/flux_text_encoders"
46
- file_name = "clip_l.safetensors"
47
- CLIP_L_PATH = hf_hub_download(repo_id=repo_id, filename=file_name)
48
- print(f"Downloaded CLIP_L_PATH saved at: {CLIP_L_PATH}")
49
-
50
- from huggingface_hub import hf_hub_download
51
- repo_id = "comfyanonymous/flux_text_encoders"
52
- file_name = "t5xxl_fp8_e4m3fn.safetensors"
53
- T5XXL_PATH = hf_hub_download(repo_id=repo_id, filename=file_name)
54
- print(f"Downloaded T5XXL_PATH saved at: {T5XXL_PATH}")
55
-
56
- from huggingface_hub import hf_hub_download
57
- repo_id = "black-forest-labs/FLUX.1-dev"
58
- file_name = "ae.safetensors"
59
- AE_PATH = hf_hub_download(repo_id=repo_id, filename=file_name)
60
- print(f"Downloaded AE_PATH saved at: {AE_PATH}")
61
-
62
- from huggingface_hub import hf_hub_download
63
- repo_id = "showlab/makeanything"
64
- file_name = "recraft/recraft_9f_lego.safetensors"
65
- LORA_WEIGHTS_PATH = hf_hub_download(repo_id=repo_id, filename=file_name)
66
- print(f"Downloaded LORA_WEIGHTS_PATH saved at: {LORA_WEIGHTS_PATH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # Load model function
69
- def load_target_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  logger.info("Loading models...")
71
  try:
72
  _, model = flux_utils.load_flow_model(
@@ -108,32 +131,36 @@ class ResizeWithPadding:
108
 
109
  # The function to generate image from a prompt and conditional image
110
  @spaces.GPU(duration=180)
111
- def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
 
 
 
 
 
112
  logger.info(f"Started generating image with prompt: {prompt}")
113
-
114
- # Load models
115
- model, [clip_l, t5xxl], ae = load_target_model()
116
 
 
 
117
  model.eval()
118
  clip_l.eval()
119
  t5xxl.eval()
120
  ae.eval()
121
 
122
- # LoRA
123
- multiplier = 1.0
124
- weights_sd = load_file(LORA_WEIGHTS_PATH)
125
- lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
126
- True)
127
-
128
- lora_model.apply_to([clip_l, t5xxl], model)
129
- info = lora_model.load_state_dict(weights_sd, strict=True)
130
- logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
131
- lora_model.eval()
132
- lora_model.to(device)
133
-
134
- # Process the seed
135
- if randomize_seed:
136
- seed = random.randint(0, np.iinfo(np.int32).max)
137
  logger.debug(f"Using seed: {seed}")
138
 
139
  # Preprocess the conditional image
@@ -199,9 +226,6 @@ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
199
 
200
  model.to(device)
201
 
202
- # import pdb
203
- # pdb.set_trace()
204
-
205
  # Run the denoising process
206
  with accelerator.autocast(), torch.no_grad():
207
  x = flux_train_utils.denoise(
@@ -232,39 +256,39 @@ with gr.Blocks() as demo:
232
  gr.Markdown("## FLUX Image Generation")
233
 
234
  with gr.Row():
 
 
 
 
 
 
 
235
  # Input for the prompt
236
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=1)
237
-
238
  # File upload for image
239
  sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
240
-
241
  # Frame number selection
242
- frame_num = gr.Radio([4, 9], label="Select Frame Number", value=9)
243
-
244
- # Seed and randomize seed options
245
  seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
246
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
247
-
 
 
248
  # Run Button
249
  run_button = gr.Button("Generate Image")
250
-
251
  # Output result
252
  result_image = gr.Image(label="Generated Image")
253
 
254
- run_button.click(
255
- fn=infer,
256
- inputs=[prompt, sample_image, frame_num, seed, randomize_seed],
257
- outputs=[result_image]
258
- )
259
 
260
  # Launch the Gradio app
261
  demo.launch()
262
-
263
-
264
- # prompt = "1girl"
265
- # sample_image = Image.open("/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/test/1.png") # 使用一个测试图像
266
- # frame_num = 9
267
- # seed = 42
268
- # randomize_seed = False
269
- # result = infer(prompt, sample_image, frame_num, seed, randomize_seed)
270
- # result.save('asy_results/generated_image.png')
 
12
  from networks import lora_flux
13
  from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
14
  import logging
15
+ from huggingface_hub import login
16
+ from huggingface_hub import hf_hub_download
17
+
18
 
19
  # Set up logger
20
  logger = logging.getLogger(__name__)
21
  logging.basicConfig(level=logging.DEBUG)
22
 
 
 
 
 
 
23
  accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
24
 
 
 
 
 
 
 
 
 
25
  hf_token = os.getenv("HF_TOKEN")
26
  login(token=hf_token)
27
 
28
+ # Model paths dynamically retrieved using selected model
29
+ model_paths = {
30
+ 'Wood Sculpture': {
31
+ 'BASE_FLUX_CHECKPOINT': "Kijai/flux-fp8",
32
+ 'BASE_FILE': "flux1-dev-fp8.safetensors",
33
+ 'LORA_REPO': "showlab/makeanything",
34
+ 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors"
35
+ },
36
+ 'LEGO': {
37
+ 'BASE_FLUX_CHECKPOINT': "Kijai/flux-fp8",
38
+ 'BASE_FILE': "flux1-dev-fp8.safetensors",
39
+ 'LORA_REPO': "showlab/makeanything",
40
+ 'LORA_FILE': "recraft/recraft_9f_lego.safetensors"
41
+ },
42
+ 'Sketch': {
43
+ 'BASE_FLUX_CHECKPOINT': "Kijai/flux-fp8",
44
+ 'BASE_FILE': "flux1-dev-fp8.safetensors",
45
+ 'LORA_REPO': "showlab/makeanything",
46
+ 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors"
47
+ },
48
+ 'Portrait': {
49
+ 'BASE_FLUX_CHECKPOINT': "Kijai/flux-fp8",
50
+ 'BASE_FILE': "flux1-dev-fp8.safetensors",
51
+ 'LORA_REPO': "showlab/makeanything",
52
+ 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors"
53
+ }
54
+ }
55
+
56
+ # Common paths
57
+ clip_repo_id = "comfyanonymous/flux_text_encoders"
58
+ t5xxl_file = "t5xxl_fp8_e4m3fn.safetensors"
59
+ clip_l_file = "clip_l.safetensors"
60
+ ae_repo_id = "black-forest-labs/FLUX.1-dev"
61
+ ae_file = "ae.safetensors"
62
+
63
+
64
+ # Model placeholders
65
+ model = None
66
+ clip_l = None
67
+ t5xxl = None
68
+ ae = None
69
+ lora_model = None
70
+
71
+ # Function to load a file from Hugging Face Hub
72
+ def download_file(repo_id, file_name):
73
+ return hf_hub_download(repo_id=repo_id, filename=file_name)
74
 
75
  # Load model function
76
+ def load_target_model(selected_model):
77
+ global model, clip_l, t5xxl, ae, lora_model
78
+
79
+ # Fetch paths based on the selected model
80
+ model_path = model_paths[selected_model]
81
+ base_checkpoint_repo = model_path['BASE_FLUX_CHECKPOINT']
82
+ base_checkpoint_file = model_path['BASE_FILE']
83
+ lora_repo = model_path['LORA_REPO']
84
+ lora_file = model_path['LORA_FILE']
85
+
86
+ # Download necessary files
87
+ BASE_FLUX_CHECKPOINT = download_file(base_checkpoint_repo, base_checkpoint_file)
88
+ CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
89
+ T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
90
+ AE_PATH = download_file(ae_repo_id, ae_file)
91
+ LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
92
+
93
  logger.info("Loading models...")
94
  try:
95
  _, model = flux_utils.load_flow_model(
 
131
 
132
  # The function to generate image from a prompt and conditional image
133
  @spaces.GPU(duration=180)
134
+ def infer(prompt, sample_image, frame_num, seed=0):
135
+ global model, clip_l, t5xxl, ae, lora_model
136
+ if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
137
+ logger.error("Models not loaded. Please load the models first.")
138
+ return None
139
+
140
  logger.info(f"Started generating image with prompt: {prompt}")
 
 
 
141
 
142
+ lora_model.to("cuda")
143
+
144
  model.eval()
145
  clip_l.eval()
146
  t5xxl.eval()
147
  ae.eval()
148
 
149
+ # # Load models
150
+ # model, [clip_l, t5xxl], ae = load_target_model()
151
+
152
+ # # LoRA
153
+ # multiplier = 1.0
154
+ # weights_sd = load_file(LORA_WEIGHTS_PATH)
155
+ # lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
156
+ # True)
157
+
158
+ # lora_model.apply_to([clip_l, t5xxl], model)
159
+ # info = lora_model.load_state_dict(weights_sd, strict=True)
160
+ # logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
161
+ # lora_model.eval()
162
+ # lora_model.to(device)
163
+
164
  logger.debug(f"Using seed: {seed}")
165
 
166
  # Preprocess the conditional image
 
226
 
227
  model.to(device)
228
 
 
 
 
229
  # Run the denoising process
230
  with accelerator.autocast(), torch.no_grad():
231
  x = flux_train_utils.denoise(
 
256
  gr.Markdown("## FLUX Image Generation")
257
 
258
  with gr.Row():
259
+ # Dropdown for selecting the recraft model
260
+ recraft_model = gr.Dropdown(
261
+ label="Select Recraft Model",
262
+ choices=["Wood Sculpture", "LEGO", "Sketch", "Portrait"],
263
+ value="Wood Sculpture"
264
+ )
265
+
266
  # Input for the prompt
267
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=1)
268
+
269
  # File upload for image
270
  sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
271
+
272
  # Frame number selection
273
+ frame_num = gr.Radio([4, 9], label="Select Frame Number", value=4)
274
+
275
+ # Seed
276
  seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
277
+
278
+ # Load Model Button
279
+ load_button = gr.Button("Load Model")
280
+
281
  # Run Button
282
  run_button = gr.Button("Generate Image")
283
+
284
  # Output result
285
  result_image = gr.Image(label="Generated Image")
286
 
287
+ # Load model button action
288
+ load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[])
289
+
290
+ # Run Button
291
+ run_button.click(fn=infer, inputs=[prompt, sample_image, frame_num, seed], outputs=[result_image])
292
 
293
  # Launch the Gradio app
294
  demo.launch()