Spaces:
Runtime error
Runtime error
modified: gradio_app.py
Browse files- gradio_app.py +60 -31
gradio_app.py
CHANGED
@@ -94,33 +94,35 @@ def load_target_model(selected_model):
|
|
94 |
AE_PATH = download_file(ae_repo_id, ae_file)
|
95 |
LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
124 |
|
125 |
# Image pre-processing (resize and padding)
|
126 |
class ResizeWithPadding:
|
@@ -154,10 +156,37 @@ class ResizeWithPadding:
|
|
154 |
# The function to generate image from a prompt and conditional image
|
155 |
@spaces.GPU(duration=180)
|
156 |
def infer(prompt, sample_image, recraft_model, seed=0):
|
157 |
-
global model, clip_l, t5xxl, ae, lora_model
|
158 |
-
if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
model_path = model_paths[recraft_model]
|
163 |
frame_num = model_path['Frame']
|
|
|
94 |
AE_PATH = download_file(ae_repo_id, ae_file)
|
95 |
LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
|
96 |
|
97 |
+
return "Models loaded successfully. Using Recraft: {}".format(selected_model)
|
98 |
+
|
99 |
+
# logger.info("Loading models...")
|
100 |
+
# try:
|
101 |
+
# if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
102 |
+
# _, model = flux_utils.load_flow_model(
|
103 |
+
# BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
104 |
+
# )
|
105 |
+
# clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
106 |
+
# clip_l.eval()
|
107 |
+
# t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
108 |
+
# t5xxl.eval()
|
109 |
+
# ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
110 |
+
|
111 |
+
# # Load LoRA weights
|
112 |
+
# multiplier = 1.0
|
113 |
+
# weights_sd = load_file(LORA_WEIGHTS_PATH)
|
114 |
+
# lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
115 |
+
# lora_model.apply_to([clip_l, t5xxl], model)
|
116 |
+
# info = lora_model.load_state_dict(weights_sd, strict=True)
|
117 |
+
# logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
118 |
+
# lora_model.eval()
|
119 |
+
|
120 |
+
# logger.info("Models loaded successfully.")
|
121 |
+
# return "Models loaded successfully. Using Recraft: {}".format(selected_model)
|
122 |
+
|
123 |
+
# except Exception as e:
|
124 |
+
# logger.error(f"Error loading models: {e}")
|
125 |
+
# return f"Error loading models: {e}"
|
126 |
|
127 |
# Image pre-processing (resize and padding)
|
128 |
class ResizeWithPadding:
|
|
|
156 |
# The function to generate image from a prompt and conditional image
|
157 |
@spaces.GPU(duration=180)
|
158 |
def infer(prompt, sample_image, recraft_model, seed=0):
|
159 |
+
# global model, clip_l, t5xxl, ae, lora_model
|
160 |
+
# if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
161 |
+
# logger.error("Models not loaded. Please load the models first.")
|
162 |
+
# return None
|
163 |
+
logger.info("Loading models...")
|
164 |
+
try:
|
165 |
+
if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
166 |
+
_, model = flux_utils.load_flow_model(
|
167 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cuda", disable_mmap=False
|
168 |
+
)
|
169 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cuda", disable_mmap=False)
|
170 |
+
clip_l.eval()
|
171 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cuda", disable_mmap=False)
|
172 |
+
t5xxl.eval()
|
173 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cuda", disable_mmap=False)
|
174 |
+
|
175 |
+
# Load LoRA weights
|
176 |
+
multiplier = 1.0
|
177 |
+
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
178 |
+
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
179 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
180 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
181 |
+
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
182 |
+
lora_model.eval()
|
183 |
+
|
184 |
+
logger.info("Models loaded successfully.")
|
185 |
+
# return "Models loaded successfully. Using Recraft: {}".format(selected_model)
|
186 |
+
|
187 |
+
except Exception as e:
|
188 |
+
logger.error(f"Error loading models: {e}")
|
189 |
+
return f"Error loading models: {e}"
|
190 |
|
191 |
model_path = model_paths[recraft_model]
|
192 |
frame_num = model_path['Frame']
|