fffiloni commited on
Commit
8d74f2f
·
verified ·
1 Parent(s): a86116e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import gradio as gr
3
  from main import setup, execute_task
4
  from arguments import parse_args
@@ -51,6 +52,7 @@ def clean_dir(save_dir):
51
 
52
  def start_over(gallery_state, loaded_model_setup):
53
  torch.cuda.empty_cache() # Free up cached memory
 
54
  if gallery_state is not None:
55
  gallery_state = None
56
  if loaded_model_setup is not None:
@@ -63,6 +65,7 @@ def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_i
63
 
64
  """Clear CUDA memory before starting the training."""
65
  torch.cuda.empty_cache() # Free up cached memory
 
66
 
67
  # Set up arguments
68
  args = parse_args()
@@ -108,6 +111,7 @@ def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_i
108
 
109
  def generate_image(setup_args, num_iterations):
110
  torch.cuda.empty_cache() # Free up cached memory
 
111
 
112
  args = setup_args[0]
113
  trainer = setup_args[1]
@@ -125,6 +129,7 @@ def generate_image(setup_args, num_iterations):
125
 
126
  try:
127
  torch.cuda.empty_cache() # Free up cached memory
 
128
  steps_completed = []
129
  result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None}
130
  error_status = {"error_occurred": False} # Shared dictionary to track error status
@@ -175,6 +180,7 @@ def generate_image(setup_args, num_iterations):
175
 
176
  if error_status["error_occurred"]:
177
  torch.cuda.empty_cache() # Free up cached memory
 
178
  yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
179
  else:
180
  main_thread.join() # Ensure thread completion
@@ -182,13 +188,16 @@ def generate_image(setup_args, num_iterations):
182
  if os.path.exists(final_image_path):
183
  iter_images = list_iter_images(save_dir)
184
  torch.cuda.empty_cache() # Free up cached memory
 
185
  time.sleep(0.5)
186
  yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
187
  else:
188
  torch.cuda.empty_cache() # Free up cached memory
 
189
  yield (None, "Image generation completed, but no final image was found.", None)
190
 
191
  torch.cuda.empty_cache() # Free up cached memory
 
192
 
193
  except torch.cuda.OutOfMemoryError as e:
194
  print(f"Global CUDA Out of Memory Error: {e}")
 
1
  import torch
2
+ import gc
3
  import gradio as gr
4
  from main import setup, execute_task
5
  from arguments import parse_args
 
52
 
53
  def start_over(gallery_state, loaded_model_setup):
54
  torch.cuda.empty_cache() # Free up cached memory
55
+ gc.collect()
56
  if gallery_state is not None:
57
  gallery_state = None
58
  if loaded_model_setup is not None:
 
65
 
66
  """Clear CUDA memory before starting the training."""
67
  torch.cuda.empty_cache() # Free up cached memory
68
+ gc.collect()
69
 
70
  # Set up arguments
71
  args = parse_args()
 
111
 
112
  def generate_image(setup_args, num_iterations):
113
  torch.cuda.empty_cache() # Free up cached memory
114
+ gc.collect()
115
 
116
  args = setup_args[0]
117
  trainer = setup_args[1]
 
129
 
130
  try:
131
  torch.cuda.empty_cache() # Free up cached memory
132
+ gc.collect()
133
  steps_completed = []
134
  result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None}
135
  error_status = {"error_occurred": False} # Shared dictionary to track error status
 
180
 
181
  if error_status["error_occurred"]:
182
  torch.cuda.empty_cache() # Free up cached memory
183
+ gc.collect()
184
  yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
185
  else:
186
  main_thread.join() # Ensure thread completion
 
188
  if os.path.exists(final_image_path):
189
  iter_images = list_iter_images(save_dir)
190
  torch.cuda.empty_cache() # Free up cached memory
191
+ gc.collect()
192
  time.sleep(0.5)
193
  yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
194
  else:
195
  torch.cuda.empty_cache() # Free up cached memory
196
+ gc.collect()
197
  yield (None, "Image generation completed, but no final image was found.", None)
198
 
199
  torch.cuda.empty_cache() # Free up cached memory
200
+ gc.collect()
201
 
202
  except torch.cuda.OutOfMemoryError as e:
203
  print(f"Global CUDA Out of Memory Error: {e}")