yiren98 committed on
Commit
bf3710c
·
1 Parent(s): 7b123f5
Files changed (1) hide show
  1. gradio_app.py +24 -21
gradio_app.py CHANGED
@@ -100,7 +100,6 @@ def load_target_model(selected_model):
100
  t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
101
  t5xxl.eval()
102
  ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
103
- logger.info("Models loaded successfully.")
104
 
105
  # Load LoRA weights
106
  multiplier = 1.0
@@ -111,6 +110,8 @@ def load_target_model(selected_model):
111
  logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
112
  lora_model.eval()
113
 
 
 
114
  except Exception as e:
115
  logger.error(f"Error loading models: {e}")
116
  raise
@@ -129,19 +130,19 @@ class ResizeWithPadding:
129
 
130
  width, height = img.size
131
 
132
- # Convert to RGB to remove transparency, fill with white background if necessary
133
- if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
134
- background = Image.new("RGB", img.size, (fill, fill, fill))
135
- background.paste(img, mask=img.split()[-1]) # Use alpha channel as mask
136
- img = background
137
-
138
- if width == height:
139
- img = img.resize((self.size, self.size), Image.LANCZOS)
140
- else:
141
- max_dim = max(width, height)
142
- new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
143
- new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
144
- img = new_img.resize((self.size, self.size), Image.LANCZOS)
145
  return img
146
 
147
  # The function to generate image from a prompt and conditional image
@@ -197,9 +198,11 @@ def infer(prompt, sample_image, frame_num, seed=0):
197
  logger.debug("Image encoded to latents.")
198
 
199
  conditions = {}
200
- conditions[prompt] = latents.to("cpu")
 
 
201
 
202
- ae.to("cpu")
203
  clip_l.to(device)
204
  t5xxl.to(device)
205
 
@@ -236,8 +239,8 @@ def infer(prompt, sample_image, frame_num, seed=0):
236
  args = lambda: None
237
  args.frame_num = frame_num
238
 
239
- clip_l.to("cpu")
240
- t5xxl.to("cpu")
241
 
242
  model.to(device)
243
 
@@ -251,12 +254,12 @@ def infer(prompt, sample_image, frame_num, seed=0):
251
  # Decode the final image
252
  x = x.float()
253
  x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
254
- model.to("cpu")
255
  ae.to(device)
256
  with accelerator.autocast(), torch.no_grad():
257
  x = ae.decode(x)
258
  logger.debug("Latents decoded into image.")
259
- ae.to("cpu")
260
 
261
  # Convert the tensor to an image
262
  x = x.clamp(-1, 1)
@@ -285,7 +288,7 @@ with gr.Blocks() as demo:
285
  sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
286
 
287
  # Frame number selection
288
- frame_num = gr.Radio([4, 9], label="Select Frame Number", value=4)
289
 
290
  # Seed
291
  seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
 
100
  t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
101
  t5xxl.eval()
102
  ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
 
103
 
104
  # Load LoRA weights
105
  multiplier = 1.0
 
110
  logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
111
  lora_model.eval()
112
 
113
+ logger.info("Models loaded successfully.")
114
+
115
  except Exception as e:
116
  logger.error(f"Error loading models: {e}")
117
  raise
 
130
 
131
  width, height = img.size
132
 
133
+ # # Convert to RGB to remove transparency, fill with white background if necessary
134
+ # if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
135
+ # background = Image.new("RGB", img.size, (fill, fill, fill))
136
+ # background.paste(img, mask=img.split()[-1]) # Use alpha channel as mask
137
+ # img = background
138
+
139
+ # if width == height:
140
+ # img = img.resize((self.size, self.size), Image.LANCZOS)
141
+ # else:
142
+ max_dim = max(width, height)
143
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
144
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
145
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
146
  return img
147
 
148
  # The function to generate image from a prompt and conditional image
 
198
  logger.debug("Image encoded to latents.")
199
 
200
  conditions = {}
201
+ # conditions[prompt] = latents.to("cpu")
202
+ conditions[prompt] = latents
203
+
204
 
205
+ # ae.to("cpu")
206
  clip_l.to(device)
207
  t5xxl.to(device)
208
 
 
239
  args = lambda: None
240
  args.frame_num = frame_num
241
 
242
+ # clip_l.to("cpu")
243
+ # t5xxl.to("cpu")
244
 
245
  model.to(device)
246
 
 
254
  # Decode the final image
255
  x = x.float()
256
  x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
257
+ # model.to("cpu")
258
  ae.to(device)
259
  with accelerator.autocast(), torch.no_grad():
260
  x = ae.decode(x)
261
  logger.debug("Latents decoded into image.")
262
+ # ae.to("cpu")
263
 
264
  # Convert the tensor to an image
265
  x = x.clamp(-1, 1)
 
288
  sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
289
 
290
  # Frame number selection
291
+ frame_num = gr.Radio([4, 9], label="Select Frame Number", value=9)
292
 
293
  # Seed
294
  seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)