yiren98 committed
Commit 656a974 · verified · 1 Parent(s): 04989f6

Update gradio_app.py

Files changed (1):
  1. gradio_app.py +8 -7
gradio_app.py CHANGED

@@ -18,6 +18,8 @@ logging.basicConfig(level=logging.DEBUG)
 
 # Ensure necessary devices are available
 device = "cuda" if torch.cuda.is_available() else "cpu"
+logger.info("device: ", device)
+
 accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
 
 # Model paths (replace these with your actual model paths)
@@ -124,7 +126,7 @@ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
     info = lora_model.load_state_dict(weights_sd, strict=True)
     logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
     lora_model.eval()
-    lora_model.to("cuda")
+    lora_model.to(device)
 
     # Process the seed
     if randomize_seed:
@@ -145,7 +147,7 @@ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
     logger.debug("Conditional image preprocessed.")
 
     # Encode the image to latents
-    ae.to("cuda")
+    ae.to(device)
     latents = ae.encode(image)
     logger.debug("Image encoded to latents.")
 
@@ -153,8 +155,8 @@ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
     conditions[prompt] = latents.to("cpu")
 
     ae.to("cpu")
-    clip_l.to("cuda")
-    t5xxl.to("cuda")
+    clip_l.to(device)
+    t5xxl.to(device)
 
     # Encode the prompt
     tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
@@ -192,8 +194,7 @@ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
     clip_l.to("cpu")
     t5xxl.to("cpu")
 
-    torch.cuda.empty_cache()
-    model.to("cuda")
+    model.to(device)
 
     # import pdb
     # pdb.set_trace()
@@ -209,7 +210,7 @@ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
     x = x.float()
     x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
     model.to("cpu")
-    ae.to("cuda")
+    ae.to(device)
     with accelerator.autocast(), torch.no_grad():
         x = ae.decode(x)
     logger.debug("Latents decoded into image.")
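
One caveat with the new logging line in the first hunk: unlike `print`, the stdlib `logging` module does not concatenate extra positional arguments — it treats them as %-style substitutions for the message, and since `"device: "` contains no `%s` placeholder the record fails to format at emit time (logging prints a "--- Logging error ---" traceback to stderr instead of the message). A corrected form, in either the lazy %-style or the f-string style the file already uses elsewhere:

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
device = "cpu"

# Broken: logging is not print; the stray argument has no %s placeholder,
# so the record fails to format when it is emitted.
# logger.info("device: ", device)

# Lazy %-style formatting (interpolated only if the record is emitted):
logger.info("device: %s", device)

# Or an f-string, matching the file's other log calls:
logger.info(f"device: {device}")
```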
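Beyond the cuda-to-device renames, the structure the commit preserves is sequential offloading: resolve the target device once, then move each heavy module onto it only for the stage that needs it and back to the CPU afterwards, so the AE, CLIP-L, T5-XXL and the DiT model never occupy device memory at the same time. A minimal sketch of that pattern, assuming nothing about the real models — `run_offloaded` and the `nn.Linear` stand-ins are hypothetical, not names from gradio_app.py:

```python
import torch
from torch import nn

# Resolve the device once, as the commit now does, instead of hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

def run_offloaded(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Move `module` to the device for one forward pass, then park it on the CPU."""
    module.to(device)
    try:
        with torch.no_grad():
            return module(x.to(device)).to("cpu")
    finally:
        module.to("cpu")  # release device memory for the next stage

# Hypothetical stand-ins for the real pipeline stages.
autoencoder = nn.Linear(16, 4)
text_encoder = nn.Linear(16, 8)

x = torch.randn(2, 16)
latents = run_offloaded(autoencoder, x)      # stage 1 occupies the device
embeddings = run_offloaded(text_encoder, x)  # stage 2 reuses the same memory
```

On a CPU-only machine every `.to(...)` call degrades to a no-op, which is what makes the renamed code portable; the dropped `torch.cuda.empty_cache()` was only ever meaningful on the CUDA path.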
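Finally, the decode step in the last hunk runs under `accelerator.autocast()`; with `Accelerator(mixed_precision='bf16')` that context casts eligible ops such as linear layers and matmuls to bfloat16 on whichever device is active. A standalone sketch of the same pairing — the `nn.Linear` is a placeholder, not the app's autoencoder:

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")  # same setup as gradio_app.py

decoder = nn.Linear(8, 8).to(accelerator.device)
latents = torch.randn(2, 8, device=accelerator.device)

# autocast handles the bf16 casts; no_grad skips autograd bookkeeping,
# since decoding is pure inference.
with accelerator.autocast(), torch.no_grad():
    decoded = decoder(latents)

print(decoded.dtype)  # torch.bfloat16 under the bf16 autocast
```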