fffiloni commited on
Commit
1fcccb4
·
verified ·
1 Parent(s): 9de4480

yield current iterations

Browse files
Files changed (1) hide show
  1. app.py +95 -28
app.py CHANGED
@@ -1,9 +1,15 @@
 
1
  import gradio as gr
2
- from main import main
3
  from arguments import parse_args
4
  import os
5
  import shutil
6
  import glob
 
 
 
 
 
7
 
8
  def list_iter_images(save_dir):
9
  # Specify the image extensions you want to search for
@@ -43,7 +49,16 @@ def clean_dir(save_dir):
43
  else:
44
  print(f"{save_dir} does not exist.")
45
 
46
- def generate_image(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
47
  # Set up arguments
48
  args = parse_args()
49
  args.task = "single"
@@ -55,45 +70,83 @@ def generate_image(prompt, model, num_iterations, learning_rate, progress=gr.Pro
55
  args.save_dir = "./outputs"
56
  args.save_all_images = True
57
 
58
- settings = (
59
- f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
60
- f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
61
- f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
62
- f"_reg{args.reg_weight if args.enable_reg else '0'}"
63
- f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
64
- f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
65
- f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
66
- f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
67
- f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
68
- )
 
 
 
 
69
 
70
  save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
71
  clean_dir(save_dir)
72
 
73
  try:
74
- # Run the main function with progress tracking
 
 
 
75
  def progress_callback(step):
76
- progress(step / num_iterations, f"Iteration {step}/{num_iterations}")
77
 
78
- best_image, total_init_rewards, total_best_rewards = main(args, progress_callback)
79
-
80
- # Return the path to the generated image
81
- image_path = f"{save_dir}/best_image.png"
82
 
83
- if os.path.exists(image_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  iter_images = list_iter_images(save_dir)
85
- return image_path, f"Image generated successfully and saved at {image_path}", iter_images
86
  else:
87
- return None, "Image generation completed, but the file was not found.", None
88
-
89
  except Exception as e:
90
- return None, f"An error occurred: {str(e)}", None
 
 
 
 
 
 
91
 
92
  # Create Gradio interface
93
  title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
94
  description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
95
 
96
  with gr.Blocks() as demo:
 
 
97
  with gr.Column():
98
  gr.Markdown(title)
99
  gr.Markdown(description)
@@ -111,7 +164,9 @@ with gr.Blocks() as demo:
111
  with gr.Row():
112
  with gr.Column():
113
  prompt = gr.Textbox(label="Prompt")
114
- chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model", value="sd-turbo")
 
 
115
 
116
  with gr.Row():
117
  n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
@@ -134,12 +189,24 @@ with gr.Blocks() as demo:
134
  with gr.Column():
135
  output_image = gr.Image(type="filepath", label="Best Generated Image")
136
  status = gr.Textbox(label="Status")
137
- iter_gallery = gr.Gallery(label="Iterations", columns=4)
138
 
139
  submit_btn.click(
140
- fn = generate_image,
 
 
 
 
141
  inputs = [prompt, chosen_model, n_iter, learning_rate],
142
- outputs = [output_image, status, iter_gallery]
 
 
 
 
 
 
 
 
143
  )
144
 
145
  # Launch the app
 
1
+ import torch
2
  import gradio as gr
3
+ from main import setup, execute_task
4
  from arguments import parse_args
5
  import os
6
  import shutil
7
  import glob
8
+ import time
9
+ import threading
10
+ import argparse
11
+
12
+
13
 
14
  def list_iter_images(save_dir):
15
  # Specify the image extensions you want to search for
 
49
  else:
50
  print(f"{save_dir} does not exist.")
51
 
52
+ def start_over(gallery_state):
53
+ if gallery_state is not None:
54
+ gallery_state = None
55
+ return gallery_state, None, None, gr.update(visible=False)
56
+
57
+ def setup_model(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
58
+
59
+ """Clear CUDA memory before starting the training."""
60
+ torch.cuda.empty_cache() # Free up cached memory
61
+
62
  # Set up arguments
63
  args = parse_args()
64
  args.task = "single"
 
70
  args.save_dir = "./outputs"
71
  args.save_all_images = True
72
 
73
+ args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
74
+ loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings]
75
+
76
+ return None, loaded_setup
77
+
78
+ def generate_image(setup_args, num_iterations):
79
+
80
+ args = setup_args[0]
81
+ trainer = setup_args[1]
82
+ device = setup_args[2]
83
+ dtype = setup_args[3]
84
+ shape = setup_args[4]
85
+ enable_grad = setup_args[5]
86
+
87
+ settings = setup_args[6]
88
 
89
  save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
90
  clean_dir(save_dir)
91
 
92
  try:
93
+ steps_completed = []
94
+ result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None}
95
+
96
+ # Define progress_callback that updates steps_completed
97
  def progress_callback(step):
98
+ steps_completed.append(step)
99
 
100
+ # Function to run main in a separate thread
101
+ def run_main():
102
+ result_container["best_image"], result_container["total_init_rewards"], result_container["total_best_rewards"] = execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback)
 
103
 
104
+ # Start main in a separate thread
105
+ main_thread = threading.Thread(target=run_main)
106
+ main_thread.start()
107
+
108
+ last_step_yielded = 0
109
+ while main_thread.is_alive() or last_step_yielded < num_iterations:
110
+ # Check if new steps have been completed
111
+ if steps_completed and steps_completed[-1] > last_step_yielded:
112
+ last_step_yielded = steps_completed[-1]
113
+ png_number = last_step_yielded - 1
114
+ # Get the image for this step
115
+ image_path = os.path.join(save_dir, f"{png_number}.png")
116
+ if os.path.exists(image_path):
117
+ yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
118
+ else:
119
+ yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
120
+ else:
121
+ # Small sleep to prevent busy waiting
122
+ time.sleep(0.1)
123
+
124
+ main_thread.join()
125
+
126
+ # After main is complete, yield the final image
127
+ final_image_path = os.path.join(save_dir, "best_image.png")
128
+ if os.path.exists(final_image_path):
129
  iter_images = list_iter_images(save_dir)
130
+ yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
131
  else:
132
+ yield (None, "Image generation completed, but no final image was found.", None)
133
+
134
  except Exception as e:
135
+ yield (None, f"An error occurred: {str(e)}", None)
136
+
137
+ def show_gallery_output(gallery_state):
138
+ if gallery_state is not None:
139
+ return gr.update(value=gallery_state, visible=True)
140
+ else:
141
+ return gr.update(value=None, visible=False)
142
 
143
  # Create Gradio interface
144
  title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
145
  description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
146
 
147
  with gr.Blocks() as demo:
148
+ loaded_model_setup = gr.State()
149
+ gallery_state = gr.State()
150
  with gr.Column():
151
  gr.Markdown(title)
152
  gr.Markdown(description)
 
164
  with gr.Row():
165
  with gr.Column():
166
  prompt = gr.Textbox(label="Prompt")
167
+ with gr.Row():
168
+ chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model", value="sd-turbo")
169
+ model_status = gr.Textbox(label="model status", visible=False)
170
 
171
  with gr.Row():
172
  n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
 
189
  with gr.Column():
190
  output_image = gr.Image(type="filepath", label="Best Generated Image")
191
  status = gr.Textbox(label="Status")
192
+ iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)
193
 
194
  submit_btn.click(
195
+ fn = start_over,
196
+ inputs =[gallery_state],
197
+ outputs = [gallery_state, output_image, status, iter_gallery]
198
+ ).then(
199
+ fn = setup_model,
200
  inputs = [prompt, chosen_model, n_iter, learning_rate],
201
+ outputs = [output_image, loaded_model_setup]
202
+ ).then(
203
+ fn = generate_image,
204
+ inputs = [loaded_model_setup, n_iter],
205
+ outputs = [output_image, status, gallery_state]
206
+ ).then(
207
+ fn = show_gallery_output,
208
+ inputs = [gallery_state],
209
+ outputs = iter_gallery
210
  )
211
 
212
  # Launch the app