yiren98 commited on
Commit
21599ad
·
1 Parent(s): 536110b

modified: gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +60 -31
gradio_app.py CHANGED
@@ -94,33 +94,35 @@ def load_target_model(selected_model):
94
  AE_PATH = download_file(ae_repo_id, ae_file)
95
  LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
96
 
97
- logger.info("Loading models...")
98
- try:
99
- if model is None is None or clip_l is None or t5xxl is None or ae is None:
100
- _, model = flux_utils.load_flow_model(
101
- BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
102
- )
103
- clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
104
- clip_l.eval()
105
- t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
106
- t5xxl.eval()
107
- ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
108
-
109
- # Load LoRA weights
110
- multiplier = 1.0
111
- weights_sd = load_file(LORA_WEIGHTS_PATH)
112
- lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
113
- lora_model.apply_to([clip_l, t5xxl], model)
114
- info = lora_model.load_state_dict(weights_sd, strict=True)
115
- logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
116
- lora_model.eval()
117
-
118
- logger.info("Models loaded successfully.")
119
- return "Models loaded successfully. Using Recraft: {}".format(selected_model)
120
-
121
- except Exception as e:
122
- logger.error(f"Error loading models: {e}")
123
- return f"Error loading models: {e}"
 
 
124
 
125
  # Image pre-processing (resize and padding)
126
  class ResizeWithPadding:
@@ -154,10 +156,37 @@ class ResizeWithPadding:
154
  # The function to generate image from a prompt and conditional image
155
  @spaces.GPU(duration=180)
156
  def infer(prompt, sample_image, recraft_model, seed=0):
157
- global model, clip_l, t5xxl, ae, lora_model
158
- if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
159
- logger.error("Models not loaded. Please load the models first.")
160
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  model_path = model_paths[recraft_model]
163
  frame_num = model_path['Frame']
 
94
  AE_PATH = download_file(ae_repo_id, ae_file)
95
  LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
96
 
97
+ return "Models loaded successfully. Using Recraft: {}".format(selected_model)
98
+
99
+ # logger.info("Loading models...")
100
+ # try:
101
+ # if model is None is None or clip_l is None or t5xxl is None or ae is None:
102
+ # _, model = flux_utils.load_flow_model(
103
+ # BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
104
+ # )
105
+ # clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
106
+ # clip_l.eval()
107
+ # t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
108
+ # t5xxl.eval()
109
+ # ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
110
+
111
+ # # Load LoRA weights
112
+ # multiplier = 1.0
113
+ # weights_sd = load_file(LORA_WEIGHTS_PATH)
114
+ # lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
115
+ # lora_model.apply_to([clip_l, t5xxl], model)
116
+ # info = lora_model.load_state_dict(weights_sd, strict=True)
117
+ # logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
118
+ # lora_model.eval()
119
+
120
+ # logger.info("Models loaded successfully.")
121
+ # return "Models loaded successfully. Using Recraft: {}".format(selected_model)
122
+
123
+ # except Exception as e:
124
+ # logger.error(f"Error loading models: {e}")
125
+ # return f"Error loading models: {e}"
126
 
127
  # Image pre-processing (resize and padding)
128
  class ResizeWithPadding:
 
156
  # The function to generate image from a prompt and conditional image
157
  @spaces.GPU(duration=180)
158
  def infer(prompt, sample_image, recraft_model, seed=0):
159
+ # global model, clip_l, t5xxl, ae, lora_model
160
+ # if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
161
+ # logger.error("Models not loaded. Please load the models first.")
162
+ # return None
163
+ logger.info("Loading models...")
164
+ try:
165
+ if model is None is None or clip_l is None or t5xxl is None or ae is None:
166
+ _, model = flux_utils.load_flow_model(
167
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cuda", disable_mmap=False
168
+ )
169
+ clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cuda", disable_mmap=False)
170
+ clip_l.eval()
171
+ t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cuda", disable_mmap=False)
172
+ t5xxl.eval()
173
+ ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cuda", disable_mmap=False)
174
+
175
+ # Load LoRA weights
176
+ multiplier = 1.0
177
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
178
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
179
+ lora_model.apply_to([clip_l, t5xxl], model)
180
+ info = lora_model.load_state_dict(weights_sd, strict=True)
181
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
182
+ lora_model.eval()
183
+
184
+ logger.info("Models loaded successfully.")
185
+ # return "Models loaded successfully. Using Recraft: {}".format(selected_model)
186
+
187
+ except Exception as e:
188
+ logger.error(f"Error loading models: {e}")
189
+ return f"Error loading models: {e}"
190
 
191
  model_path = model_paths[recraft_model]
192
  frame_num = model_path['Frame']