alexnasa committed
Commit e297a71 · verified · 1 Parent(s): f6e8319

Update app.py

Files changed (1)
  app.py  +175 -53
app.py CHANGED
@@ -1,89 +1,137 @@
+import gradio as gr
+import subprocess
 import os
 import shutil
-import subprocess
 from pathlib import Path
-from PIL import Image
-import gradio as gr
+from PIL import Image, ImageDraw
 import spaces
 
+# ------------------------------------------------------------------
+# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
+# ------------------------------------------------------------------
+
 INPUT_DIR = "samples"
 OUTPUT_DIR = "inference_results/coz_vlmprompt"
 
+# ------------------------------------------------------------------
+# HELPER: Resize & center-crop to 512, preserving aspect ratio
+# ------------------------------------------------------------------
+
 def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
+    """
+    Resize the input PIL image so that its shorter side == `size`,
+    then center-crop to exactly (size x size).
+    """
     w, h = img.size
     scale = size / min(w, h)
     new_w, new_h = int(w * scale), int(h * scale)
     img = img.resize((new_w, new_h), Image.LANCZOS)
+
     left = (new_w - size) // 2
     top = (new_h - size) // 2
     return img.crop((left, top, left + size, top + size))
 
+
+# ------------------------------------------------------------------
+# HELPER: Draw four concentric, centered rectangles on a 512×512 image
+# ------------------------------------------------------------------
+
 def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
+    """
+    1) Open the uploaded image from disk.
+    2) Resize & center-crop it to exactly 512×512.
+    3) Depending on scale_option ("1x","2x","4x"), compute four rectangle sizes:
+       - "1x": [512, 512, 512, 512]
+       - "2x": [256, 128, 64, 32]
+       - "4x": [128, 64, 32, 16]
+    4) Draw each of those four rectangles (outline only), all centered.
+    5) Return the modified PIL image.
+    """
     try:
         orig = Image.open(image_path).convert("RGB")
     except Exception as e:
+        # If something fails, return a plain 512×512 gray image as fallback
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
-        from PIL import ImageDraw
         draw = ImageDraw.Draw(fallback)
         draw.text((20, 20), f"Error:\n{e}", fill="red")
         return fallback
-    base = resize_and_center_crop(orig, 512)
-    scale_int = int(scale_option.replace("x", ""))
-    if scale_int == 1: sizes = [512] * 4
-    else: sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]
-    from PIL import ImageDraw
+
+    # 1. Resize & center-crop to 512×512
+    base = resize_and_center_crop(orig, 512)  # now `base.size == (512,512)`
+
+    # 2. Determine the four box sizes
+    scale_int = int(scale_option.replace("x", ""))  # e.g. "2x" -> 2
+    if scale_int == 1:
+        sizes = [512, 512, 512, 512]
+    else:
+        # For scale=2: sizes = [512//2, 512//(2*2), 512//(2*4), 512//(2*8)] -> [256,128,64,32]
+        # For scale=4: sizes = [512//4, 512//(4*2), 512//(4*4), 512//(4*8)] -> [128,64,32,16]
+        sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]
+
     draw = ImageDraw.Draw(base)
+
+    # 3. Outline color cycle (you can change these or use just one color)
     colors = ["red", "lime", "cyan", "yellow"]
-    width = 3
+    width = 3  # thickness of each rectangle’s outline
+
     for idx, s in enumerate(sizes):
+        # Compute top-left corner so that box is centered in 512×512
         x0 = (512 - s) // 2
         y0 = (512 - s) // 2
         x1 = x0 + s
         y1 = y0 + s
-        draw.rectangle([(x0, y0), (x1, y1)], outline=colors[idx], width=width)
+        draw.rectangle([(x0, y0), (x1, y1)], outline=colors[idx % len(colors)], width=width)
+
     return base
 
+
+# ------------------------------------------------------------------
+# HELPER FUNCTIONS FOR INFERENCE & CAPTION (unchanged from your original)
+# ------------------------------------------------------------------
 @spaces.GPU(duration=120)
-def run_with_upload(uploaded_image_path, upscale_option, session_id=None):
+def run_with_upload(uploaded_image_path, upscale_option):
     """
-    Each invocation creates/uses:
-      - samples/<session_id>/input.png                                        user’s uploaded image
-      - inference_results/coz_vlmprompt/<session_id>/per-sample/input/*.png   inference outputs
+    1) Clear INPUT_DIR
+    2) Save the uploaded file as input.png in INPUT_DIR
+    3) Read `upscale_option` (e.g. "1x", "2x", "4x") → turn it into "1","2","4"
+    4) Call inference_coz.py with `--upscale <that_value>`
+    5) Return the FOUR output‐PNG file‐paths as a Python list, so that Gradio's Gallery
+       can display them.
     """
+    # ————————————————————————————————————————————————————————————
+    # (Copy‐paste exactly your existing code here; no changes needed)
+    # ————————————————————————————————————————————————————————————
+
+    os.makedirs(INPUT_DIR, exist_ok=True)
+    for fn in os.listdir(INPUT_DIR):
+        full_path = os.path.join(INPUT_DIR, fn)
+        try:
+            if os.path.isfile(full_path) or os.path.islink(full_path):
+                os.remove(full_path)
+            elif os.path.isdir(full_path):
+                shutil.rmtree(full_path)
+        except Exception as e:
+            print(f"Warning: could not delete {full_path}: {e}")
+
     if uploaded_image_path is None:
         return []
-    # 1) Prepare a per-session input directory
-    print(session_id)
-    session_folder = os.path.join(INPUT_DIR, str(session_id))
-    os.makedirs(session_folder, exist_ok=True)
-
-    # 2) Clear only this session’s folder
-    for fn in os.listdir(session_folder):
-        full_path = os.path.join(session_folder, fn)
-        if os.path.isfile(full_path) or os.path.islink(full_path):
-            os.remove(full_path)
-        elif os.path.isdir(full_path):
-            shutil.rmtree(full_path)
-
-    # 3) Save uploaded image to session_folder/input.png
     try:
         pil_img = Image.open(uploaded_image_path).convert("RGB")
-        save_path = Path(session_folder) / "input.png"
+    except Exception as e:
+        print(f"Error: could not open uploaded image: {e}")
+        return []
+    save_path = Path(INPUT_DIR) / "input.png"
+    try:
         pil_img.save(save_path, format="PNG")
     except Exception as e:
-        print(f"Error: could not save uploaded image: {e}")
+        print(f"Error: could not save as PNG: {e}")
         return []
 
-    # 4) Define a per-session output directory
-    session_output_dir = os.path.join(OUTPUT_DIR, str(session_id))
-    os.makedirs(session_output_dir, exist_ok=True)
-
-    # 5) Build and run the inference command
-    upscale_value = upscale_option.replace("x", "")
+    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" → "2"
     cmd = [
         "python", "inference_coz.py",
-        "-i", session_folder,
-        "-o", session_output_dir,
+        "-i", INPUT_DIR,
+        "-o", OUTPUT_DIR,
         "--rec_type", "recursive_multiscale",
         "--prompt_type", "vlm",
         "--upscale", upscale_value,
@@ -99,23 +147,32 @@ def run_with_upload(uploaded_image_path, upscale_option, session_id=None):
         print("Inference failed:", err)
         return []
 
-    # 6) Gather output file paths (1.png through 4.png)
-    per_sample_dir = os.path.join(session_output_dir, "per-sample", "input")
-    expected_files = [os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)]
+    per_sample_dir = os.path.join(OUTPUT_DIR, "per-sample", "input")
+    expected_files = [
+        os.path.join(per_sample_dir, f"{i}.png")
+        for i in range(1, 5)
+    ]
     for fp in expected_files:
         if not os.path.isfile(fp):
             print(f"Warning: expected file not found: {fp}")
             return []
     return expected_files
 
+
 def get_caption(src_gallery, evt: gr.SelectData):
+    """
+    Given a clicked‐on image in the gallery, read the corresponding .txt in
+    .../per-sample/input/txt and return its contents.
+    """
     if not src_gallery or not os.path.isfile(src_gallery[evt.index][0]):
         return "No caption available."
+
     selected_image_path = src_gallery[evt.index][0]
     base = os.path.basename(selected_image_path)   # e.g. "2.png"
     stem = os.path.splitext(base)[0]               # e.g. "2"
-    txt_folder = os.path.join(OUTPUT_DIR, str(evt.index), "per-sample", "input", "txt")
+    txt_folder = os.path.join(OUTPUT_DIR, "per-sample", "input", "txt")
    txt_path = os.path.join(txt_folder, f"{int(stem) - 1}.txt")
+
     if not os.path.isfile(txt_path):
         return f"Caption file not found: {int(stem) - 1}.txt"
     try:
@@ -125,6 +182,11 @@ def get_caption(src_gallery, evt: gr.SelectData):
     except Exception as e:
         return f"Error reading caption: {e}"
 
+
+# ------------------------------------------------------------------
+# BUILD THE GRADIO INTERFACE (with updated callbacks)
+# ------------------------------------------------------------------
+
 css = """
 #col-container {
     margin: 0 auto;
@@ -133,6 +195,7 @@ css = """
 """
 
 with gr.Blocks(css=css) as demo:
+
     gr.HTML(
         """
         <div style="text-align: center;">
@@ -149,39 +212,98 @@ with gr.Blocks(css=css) as demo:
     )
 
     with gr.Column(elem_id="col-container"):
+
         with gr.Row():
             with gr.Column():
-                upload_image = gr.Image(label="Upload your input image", type="filepath")
-                upscale_radio = gr.Radio(choices=["1x", "2x", "4x"], value="2x", show_label=False)
+                # 1) Image upload component
+                upload_image = gr.Image(
+                    label="Upload your input image",
+                    type="filepath"
+                )
+
+                # 2) Radio for choosing 1× / 2× / 4× upscaling
+                upscale_radio = gr.Radio(
+                    choices=["1x", "2x", "4x"],
+                    value="2x",
+                    show_label=False
+                )
+
+                # 3) Button to launch inference
                 run_button = gr.Button("Chain-of-Zoom it")
-                preview_with_box = gr.Image(label="Preview (512×512 with centered boxes)", type="pil", interactive=False)
+
+                # 4) Show the 512×512 preview with four centered rectangles
+                preview_with_box = gr.Image(
+                    label="Preview (512×512 with centered boxes)",
+                    type="pil",          # we’ll return a PIL.Image from our function
+                    interactive=False
+                )
+
 
             with gr.Column():
-                output_gallery = gr.Gallery(label="Inference Results", show_label=True, columns=[2], rows=[2])
-                caption_text = gr.Textbox(label="Caption", lines=4, placeholder="Click on any image above to see its caption here.")
+                # 5) Gallery to display multiple output images
+                output_gallery = gr.Gallery(
+                    label="Inference Results",
+                    show_label=True,
+                    elem_id="gallery",
+                    columns=[2], rows=[2]
+                )
+
+                # 6) Textbox under the gallery for showing captions
+                caption_text = gr.Textbox(
+                    label="Caption",
+                    lines=4,
+                    placeholder="Click on any image above to see its caption here."
+                )
+
+        # ------------------------------------------------------------------
+        # CALLBACK #1: Whenever the user uploads or changes the radio, update preview
+        # ------------------------------------------------------------------
 
+        def update_preview(img_path, scale_opt):
+            """
+            If there's no image uploaded yet, return None (Gradio will show blank).
+            Otherwise, draw the resized 512×512 + four boxes and return it.
+            """
+            if img_path is None:
+                return None
+            return make_preview_with_boxes(img_path, scale_opt)
+
+        # When the user uploads a new file:
         upload_image.change(
-            fn=lambda img_path, scale_opt: make_preview_with_boxes(img_path, scale_opt) if img_path is not None else None,
+            fn=update_preview,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box]
        )
+
+        # Also trigger preview redraw if they switch 1×/2×/4× after uploading:
        upscale_radio.change(
-            fn=lambda img_path, scale_opt: make_preview_with_boxes(img_path, scale_opt) if img_path is not None else None,
+            fn=update_preview,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box]
        )
 
-        # Note: gr.State() will pass session_id automatically
+        # ------------------------------------------------------------------
+        # CALLBACK #2: When “Chain-of-Zoom it” is clicked, run inference
+        # ------------------------------------------------------------------
+
        run_button.click(
            fn=run_with_upload,
-            inputs=[upload_image, upscale_radio, gr.State()],
+            inputs=[upload_image, upscale_radio],
            outputs=[output_gallery]
        )
 
+        # ------------------------------------------------------------------
+        # CALLBACK #3: When an image in the gallery is clicked, show its caption
+        # ------------------------------------------------------------------
+
        output_gallery.select(
            fn=get_caption,
            inputs=[output_gallery],
            outputs=[caption_text]
        )
 
-demo.launch(share=True)
+# ------------------------------------------------------------------
+# START THE GRADIO SERVER
+# ------------------------------------------------------------------
+
+demo.launch(share=True)
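
For reference, a minimal standalone sketch of the box-size arithmetic that the new make_preview_with_boxes uses (same formula as in the diff above; the helper name box_sizes is hypothetical and not part of the commit):

# Reproduce the preview box sizes for each scale option.
# Formula from the diff: 512 // (scale * 2**i) for i in 0..3, or 512 repeated for "1x".
def box_sizes(scale_option: str) -> list[int]:
    scale = int(scale_option.replace("x", ""))
    if scale == 1:
        return [512] * 4
    return [512 // (scale * (2 ** i)) for i in range(4)]

for opt in ("1x", "2x", "4x"):
    print(opt, box_sizes(opt))
# Expected output, matching the docstring in the diff:
# 1x [512, 512, 512, 512]
# 2x [256, 128, 64, 32]
# 4x [128, 64, 32, 16]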