yiren98 committed on
Commit
a73746d
·
1 Parent(s): 592a5ce
Files changed (1) hide show
  1. gradio_app.py +22 -17
gradio_app.py CHANGED
@@ -31,25 +31,29 @@ model_paths = {
31
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
32
  'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp8_e4m3fn.safetensors",
33
  'LORA_REPO': "showlab/makeanything",
34
- 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors"
 
35
  },
36
  'LEGO': {
37
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
38
  'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp8_e4m3fn.safetensors",
39
  'LORA_REPO': "showlab/makeanything",
40
- 'LORA_FILE': "recraft/recraft_9f_lego.safetensors"
 
41
  },
42
  'Sketch': {
43
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
44
  'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp8_e4m3fn.safetensors",
45
  'LORA_REPO': "showlab/makeanything",
46
- 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors"
 
47
  },
48
  'Portrait': {
49
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
50
  'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp8_e4m3fn.safetensors",
51
  'LORA_REPO': "showlab/makeanything",
52
- 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors"
 
53
  }
54
  }
55
 
@@ -92,14 +96,15 @@ def load_target_model(selected_model):
92
 
93
  logger.info("Loading models...")
94
  try:
95
- _, model = flux_utils.load_flow_model(
96
- BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
97
- )
98
- clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
99
- clip_l.eval()
100
- t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
101
- t5xxl.eval()
102
- ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
 
103
 
104
  # Load LoRA weights
105
  multiplier = 1.0
@@ -148,12 +153,15 @@ class ResizeWithPadding:
148
 
149
  # The function to generate image from a prompt and conditional image
150
  @spaces.GPU(duration=180)
151
- def infer(prompt, sample_image, frame_num, seed=0):
152
  global model, clip_l, t5xxl, ae, lora_model
153
  if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
154
  logger.error("Models not loaded. Please load the models first.")
155
  return None
156
 
 
 
 
157
  logger.info(f"Started generating image with prompt: {prompt}")
158
 
159
  lora_model.to("cuda")
@@ -288,9 +296,6 @@ with gr.Blocks() as demo:
288
  # File upload for image
289
  sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
290
 
291
- # Frame number selection
292
- frame_num = gr.Radio([4, 9], label="Select Frame Number", value=9)
293
-
294
  # Seed
295
  seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
296
 
@@ -310,7 +315,7 @@ with gr.Blocks() as demo:
310
  load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[status_box])
311
 
312
  # Run Button
313
- run_button.click(fn=infer, inputs=[prompt, sample_image, frame_num, seed], outputs=[result_image])
314
 
315
  # Launch the Gradio app
316
  demo.launch()
 
31
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
32
  'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp8_e4m3fn.safetensors",
33
  'LORA_REPO': "showlab/makeanything",
34
+ 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
35
+ "Frame": 4
36
  },
37
  'LEGO': {
38
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
39
  'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp8_e4m3fn.safetensors",
40
  'LORA_REPO': "showlab/makeanything",
41
+ 'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
42
+ "Frame": 9
43
  },
44
  'Sketch': {
45
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
46
  'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp8_e4m3fn.safetensors",
47
  'LORA_REPO': "showlab/makeanything",
48
+ 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
49
+ "Frame": 9
50
  },
51
  'Portrait': {
52
  'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
53
  'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp8_e4m3fn.safetensors",
54
  'LORA_REPO': "showlab/makeanything",
55
+ 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
56
+ "Frame": 9
57
  }
58
  }
59
 
 
96
 
97
  logger.info("Loading models...")
98
  try:
99
+ if model is None or clip_l is None or t5xxl is None or ae is None:
100
+ _, model = flux_utils.load_flow_model(
101
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
102
+ )
103
+ clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
104
+ clip_l.eval()
105
+ t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
106
+ t5xxl.eval()
107
+ ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
108
 
109
  # Load LoRA weights
110
  multiplier = 1.0
 
153
 
154
  # The function to generate image from a prompt and conditional image
155
  @spaces.GPU(duration=180)
156
+ def infer(prompt, sample_image, recraft_model, seed=0):
157
  global model, clip_l, t5xxl, ae, lora_model
158
  if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
159
  logger.error("Models not loaded. Please load the models first.")
160
  return None
161
 
162
+ model_path = model_paths[recraft_model]
163
+ frame_num = model_path['Frame']
164
+
165
  logger.info(f"Started generating image with prompt: {prompt}")
166
 
167
  lora_model.to("cuda")
 
296
  # File upload for image
297
  sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
298
 
 
 
 
299
  # Seed
300
  seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
301
 
 
315
  load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[status_box])
316
 
317
  # Run Button
318
+ run_button.click(fn=infer, inputs=[prompt, sample_image, recraft_model, seed], outputs=[result_image])
319
 
320
  # Launch the Gradio app
321
  demo.launch()