yiren98 commited on
Commit
1dd7c81
·
1 Parent(s): 9310c81

modified: gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +10 -1
gradio_app.py CHANGED
@@ -100,7 +100,16 @@ def load_target_model(selected_model):
100
  t5xxl.eval()
101
  ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
102
  logger.info("Models loaded successfully.")
103
- return model, [clip_l, t5xxl], ae
 
 
 
 
 
 
 
 
 
104
  except Exception as e:
105
  logger.error(f"Error loading models: {e}")
106
  raise
 
100
  t5xxl.eval()
101
  ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
102
  logger.info("Models loaded successfully.")
103
+
104
+ # Load LoRA weights
105
+ multiplier = 1.0
106
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
107
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
108
+ lora_model.apply_to([clip_l, t5xxl], model)
109
+ info = lora_model.load_state_dict(weights_sd, strict=True)
110
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
111
+ lora_model.eval()
112
+
113
  except Exception as e:
114
  logger.error(f"Error loading models: {e}")
115
  raise