alexnasa commited on
Commit
80fbabd
·
verified ·
1 Parent(s): b488f86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -55
app.py CHANGED
@@ -3,10 +3,12 @@ import subprocess
3
  import os
4
  import shutil
5
  from pathlib import Path
6
- from inference_coz_single import recursive_multiscale_sr
7
- from PIL import Image, ImageDraw
8
  import spaces
9
 
 
 
 
 
10
 
11
  # ------------------------------------------------------------------
12
  # CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
@@ -15,6 +17,7 @@ import spaces
15
  INPUT_DIR = "samples"
16
  OUTPUT_DIR = "inference_results/coz_vlmprompt"
17
 
 
18
  # ------------------------------------------------------------------
19
  # HELPER: Resize & center-crop to 512, preserving aspect ratio
20
  # ------------------------------------------------------------------
@@ -35,69 +38,141 @@ def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
35
 
36
 
37
  # ------------------------------------------------------------------
38
- # HELPER: Draw four concentric, centered rectangles on a 512×512 image
39
  # ------------------------------------------------------------------
40
 
41
- def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
 
 
 
 
 
42
  """
43
- 1) Open the uploaded image from disk.
44
- 2) Resize & center-crop it to exactly 512×512.
45
- 3) Depending on scale_option ("1x","2x","4x"), compute four rectangle sizes:
46
- - "1x": [512, 512, 512, 512]
47
- - "2x": [256, 128, 64, 32]
48
- - "4x": [128, 64, 32, 16]
49
- 4) Draw each of those four rectangles (outline only), all centered.
50
- 5) Return the modified PIL image.
 
 
 
 
 
 
 
 
 
 
 
51
  """
52
  try:
53
  orig = Image.open(image_path).convert("RGB")
54
  except Exception as e:
55
- # If something fails, return a plain 512×512 gray image as fallback
56
  fallback = Image.new("RGB", (512, 512), (200, 200, 200))
57
  draw = ImageDraw.Draw(fallback)
58
  draw.text((20, 20), f"Error:\n{e}", fill="red")
59
  return fallback
60
 
61
- # 1. Resize & center-crop to 512×512
62
- base = resize_and_center_crop(orig, 512) # now `base.size == (512,512)`
63
 
64
- # 2. Determine the four box sizes
65
- scale_int = int(scale_option.replace("x", "")) # e.g. "2x" -> 2
66
- if scale_int == 1:
 
67
  sizes = [512, 512, 512, 512]
68
  else:
69
- # For scale=2: sizes = [512//2, 512//(2*2), 512//(2*4), 512//(2*8)] -> [256,128,64,32]
70
- # For scale=4: sizes = [512//4, 512//(4*2), 512//(4*4), 512//(4*8)] -> [128,64,32,16]
71
- sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]
 
 
72
 
73
  draw = ImageDraw.Draw(base)
74
-
75
- # 3. Outline color cycle (you can change these or use just one color)
76
  colors = ["red", "lime", "cyan", "yellow"]
77
- width = 3 # thickness of each rectangle’s outline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- for idx, s in enumerate(sizes):
80
- # Compute top-left corner so that box is centered in 512×512
81
- x0 = (512 - s) // 2
82
- y0 = (512 - s) // 2
83
- x1 = x0 + s
84
- y1 = y0 + s
85
- draw.rectangle([(x0, y0), (x1, y1)], outline=colors[idx % len(colors)], width=width)
86
 
87
  return base
88
 
89
 
 
 
 
 
90
  @spaces.GPU(duration=120)
91
- def run_with_upload(uploaded_image_path, upscale_option):
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- upscale_value = upscale_option.replace("x", "") # e.g. "2x" → "2"
94
-
95
- return recursive_multiscale_sr(uploaded_image_path, int(upscale_value))[0]
96
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  # ------------------------------------------------------------------
100
- # BUILD THE GRADIO INTERFACE (with updated callbacks)
101
  # ------------------------------------------------------------------
102
 
103
  css = """
@@ -141,19 +216,35 @@ with gr.Blocks(css=css) as demo:
141
  show_label=False
142
  )
143
 
144
- # 3) Button to launch inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  run_button = gr.Button("Chain-of-Zoom it")
146
 
147
- # 4) Show the 512×512 preview with four centered rectangles
148
  preview_with_box = gr.Image(
149
- label="Preview (512×512 with centered boxes)",
150
- type="pil", # we’ll return a PIL.Image from our function
151
  interactive=False
152
  )
153
 
154
 
155
  with gr.Column():
156
- # 5) Gallery to display multiple output images
157
  output_gallery = gr.Gallery(
158
  label="Inference Results",
159
  show_label=True,
@@ -162,39 +253,51 @@ with gr.Blocks(css=css) as demo:
162
  )
163
 
164
  # ------------------------------------------------------------------
165
- # CALLBACK #1: Whenever the user uploads or changes the radio, update preview
166
  # ------------------------------------------------------------------
167
 
168
- def update_preview(img_path, scale_opt):
 
 
 
 
 
169
  """
170
- If there's no image uploaded yet, return None (Gradio will show blank).
171
- Otherwise, draw the resized 512×512 + four boxes and return it.
172
  """
173
  if img_path is None:
174
  return None
175
- return make_preview_with_boxes(img_path, scale_opt)
176
 
177
- # When the user uploads a new file:
178
  upload_image.change(
179
  fn=update_preview,
180
- inputs=[upload_image, upscale_radio],
181
  outputs=[preview_with_box]
182
  )
183
-
184
- # Also trigger preview redraw if they switch 1×/2×/4× after uploading:
185
  upscale_radio.change(
186
  fn=update_preview,
187
- inputs=[upload_image, upscale_radio],
 
 
 
 
 
 
 
 
 
 
188
  outputs=[preview_with_box]
189
  )
190
 
191
  # ------------------------------------------------------------------
192
- # CALLBACK #2: When “Chain-of-Zoom it” is clicked, run inference
193
  # ------------------------------------------------------------------
194
 
195
  run_button.click(
196
  fn=run_with_upload,
197
- inputs=[upload_image, upscale_radio],
198
  outputs=[output_gallery]
199
  )
200
 
@@ -203,5 +306,4 @@ with gr.Blocks(css=css) as demo:
203
  # START THE GRADIO SERVER
204
  # ------------------------------------------------------------------
205
 
206
- # 🔧 2) launch as usual
207
- demo.launch(share=True)
 
3
  import os
4
  import shutil
5
  from pathlib import Path
 
 
6
  import spaces
7
 
8
+ # import the updated recursive_multiscale_sr that expects a list of centers
9
+ from inference_coz_single import recursive_multiscale_sr
10
+
11
+ from PIL import Image, ImageDraw
12
 
13
  # ------------------------------------------------------------------
14
  # CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
 
17
  INPUT_DIR = "samples"
18
  OUTPUT_DIR = "inference_results/coz_vlmprompt"
19
 
20
+
21
  # ------------------------------------------------------------------
22
  # HELPER: Resize & center-crop to 512, preserving aspect ratio
23
  # ------------------------------------------------------------------
 
38
 
39
 
40
  # ------------------------------------------------------------------
41
+ # HELPER: Draw four true “nested” rectangles, matching the SR logic
42
  # ------------------------------------------------------------------
43
 
44
+ def make_preview_with_boxes(
45
+ image_path: str,
46
+ scale_option: str,
47
+ cx_norm: float,
48
+ cy_norm: float,
49
+ ) -> Image.Image:
50
  """
51
+ 1) Open the uploaded image, resize & center-crop to 512×512.
52
+ 2) Let scale_int = int(scale_option.replace("x","")).
53
+ Then the four nested crop‐sizes (in pixels) are:
54
+ size[0] = 512 / (scale_int^1),
55
+ size[1] = 512 / (scale_int^2),
56
+ size[2] = 512 / (scale_int^3),
57
+ size[3] = 512 / (scale_int^4).
58
+ 3) Iteratively compute each crop’s top-left in “original 512×512” space:
59
+ - Start with prev_tl = (0,0), prev_size = 512.
60
+ - For i in [0..3]:
61
+ center_abs_x = prev_tl_x + cx_norm * prev_size
62
+ center_abs_y = prev_tl_y + cy_norm * prev_size
63
+ unc_x0 = center_abs_x - (size[i]/2)
64
+ unc_y0 = center_abs_y - (size[i]/2)
65
+ clamp x0 ∈ [prev_tl_x, prev_tl_x + prev_size - size[i]]
66
+ y0 ∈ [prev_tl_y, prev_tl_y + prev_size - size[i]]
67
+ Draw a rectangle from (x0, y0) to (x0 + size[i], y0 + size[i]).
68
+ Then set prev_tl = (x0, y0), prev_size = size[i].
69
+ 4) Return the PIL image with those four truly nested outlines.
70
  """
71
  try:
72
  orig = Image.open(image_path).convert("RGB")
73
  except Exception as e:
74
+ # On error, return a gray 512×512 with the error text
75
  fallback = Image.new("RGB", (512, 512), (200, 200, 200))
76
  draw = ImageDraw.Draw(fallback)
77
  draw.text((20, 20), f"Error:\n{e}", fill="red")
78
  return fallback
79
 
80
+ # 1) Resize & center-crop to 512×512
81
+ base = resize_and_center_crop(orig, 512)
82
 
83
+ # 2) Compute the four nested crop‐sizes
84
+ scale_int = int(scale_option.replace("x", "")) # e.g. "4x" 4
85
+ if scale_int <= 1:
86
+ # If 1×, then all “nested” sizes are 512 (no real nesting)
87
  sizes = [512, 512, 512, 512]
88
  else:
89
+ sizes = [
90
+ 512 // (scale_int ** (i + 1))
91
+ for i in range(4)
92
+ ]
93
+ # e.g. if scale_int=4 → sizes = [128, 32, 8, 2]
94
 
95
  draw = ImageDraw.Draw(base)
 
 
96
  colors = ["red", "lime", "cyan", "yellow"]
97
+ width = 3
98
+
99
+ # 3) Iteratively compute nested rectangles
100
+ prev_tl_x, prev_tl_y = 0.0, 0.0
101
+ prev_size = 512.0
102
+
103
+ for idx, crop_size in enumerate(sizes):
104
+ # 3.a) Where is the “normalized center” in this current 512×512 region?
105
+ center_abs_x = prev_tl_x + (cx_norm * prev_size)
106
+ center_abs_y = prev_tl_y + (cy_norm * prev_size)
107
+
108
+ # 3.b) Unclamped top-left for this crop
109
+ unc_x0 = center_abs_x - (crop_size / 2.0)
110
+ unc_y0 = center_abs_y - (crop_size / 2.0)
111
+
112
+ # 3.c) Clamp so the crop window stays inside [prev_tl .. prev_tl + prev_size]
113
+ min_x0 = prev_tl_x
114
+ max_x0 = prev_tl_x + prev_size - crop_size
115
+ min_y0 = prev_tl_y
116
+ max_y0 = prev_tl_y + prev_size - crop_size
117
+
118
+ x0 = max(min_x0, min(unc_x0, max_x0))
119
+ y0 = max(min_y0, min(unc_y0, max_y0))
120
+ x1 = x0 + crop_size
121
+ y1 = y0 + crop_size
122
+
123
+ # Draw the rectangle (cast to int for pixels)
124
+ draw.rectangle(
125
+ [(int(x0), int(y0)), (int(x1), int(y1))],
126
+ outline=colors[idx % len(colors)],
127
+ width=width
128
+ )
129
 
130
+ # 3.d) Update for the next iteration
131
+ prev_tl_x, prev_tl_y = x0, y0
132
+ prev_size = crop_size
 
 
 
 
133
 
134
  return base
135
 
136
 
137
+ # ------------------------------------------------------------------
138
+ # HELPER FUNCTION FOR INFERENCE (build a list of identical centers)
139
+ # ------------------------------------------------------------------
140
+
141
  @spaces.GPU(duration=120)
142
+ def run_with_upload(
143
+ uploaded_image_path: str,
144
+ upscale_option: str,
145
+ cx_norm: float,
146
+ cy_norm: float,
147
+ ):
148
+ """
149
+ - upscale_option: "1x" / "2x" / "4x"
150
+ - cx_norm, cy_norm: normalized center coordinates in [0,1]
151
+ The underlying `recursive_multiscale_sr` expects a list of centers
152
+ of length rec_num (default 4). We replicate (cx_norm, cy_norm) four times.
153
+ """
154
+ if uploaded_image_path is None:
155
+ return []
156
 
157
+ upscale_value = int(upscale_option.replace("x", ""))
158
+ rec_num = 4 # match the SR pipeline’s default recursion depth
 
159
 
160
+ centers = [(cx_norm, cy_norm)] * rec_num
161
+
162
+ # Call the modified SR function
163
+ sr_list, _ = recursive_multiscale_sr(
164
+ uploaded_image_path,
165
+ upscale=upscale_value,
166
+ rec_num=rec_num,
167
+ centers=centers,
168
+ )
169
+
170
+ # Return the list of PIL images (Gradio Gallery expects a list)
171
+ return sr_list
172
 
173
 
174
  # ------------------------------------------------------------------
175
+ # BUILD THE GRADIO INTERFACE (two sliders + correct preview)
176
  # ------------------------------------------------------------------
177
 
178
  css = """
 
216
  show_label=False
217
  )
218
 
219
+ # 3) Two sliders for normalized center (0..1)
220
+ center_x = gr.Slider(
221
+ label="Center X (normalized)",
222
+ minimum=0.0,
223
+ maximum=1.0,
224
+ step=0.01,
225
+ value=0.5
226
+ )
227
+ center_y = gr.Slider(
228
+ label="Center Y (normalized)",
229
+ minimum=0.0,
230
+ maximum=1.0,
231
+ step=0.01,
232
+ value=0.5
233
+ )
234
+
235
+ # 4) Button to launch inference
236
  run_button = gr.Button("Chain-of-Zoom it")
237
 
238
+ # 5) Preview (512×512 + four truly nested boxes)
239
  preview_with_box = gr.Image(
240
+ label="Preview (512×512 with nested boxes)",
241
+ type="pil",
242
  interactive=False
243
  )
244
 
245
 
246
  with gr.Column():
247
+ # 6) Gallery to display multiple output images
248
  output_gallery = gr.Gallery(
249
  label="Inference Results",
250
  show_label=True,
 
253
  )
254
 
255
  # ------------------------------------------------------------------
256
+ # CALLBACK #1: update the preview whenever inputs change
257
  # ------------------------------------------------------------------
258
 
259
+ def update_preview(
260
+ img_path: str,
261
+ scale_opt: str,
262
+ cx: float,
263
+ cy: float
264
+ ) -> Image.Image | None:
265
  """
266
+ If no image uploaded, show blank; otherwise, draw four nested boxes
267
+ exactly as the SR pipeline would crop at each recursion.
268
  """
269
  if img_path is None:
270
  return None
271
+ return make_preview_with_boxes(img_path, scale_opt, cx, cy)
272
 
 
273
  upload_image.change(
274
  fn=update_preview,
275
+ inputs=[upload_image, upscale_radio, center_x, center_y],
276
  outputs=[preview_with_box]
277
  )
 
 
278
  upscale_radio.change(
279
  fn=update_preview,
280
+ inputs=[upload_image, upscale_radio, center_x, center_y],
281
+ outputs=[preview_with_box]
282
+ )
283
+ center_x.change(
284
+ fn=update_preview,
285
+ inputs=[upload_image, upscale_radio, center_x, center_y],
286
+ outputs=[preview_with_box]
287
+ )
288
+ center_y.change(
289
+ fn=update_preview,
290
+ inputs=[upload_image, upscale_radio, center_x, center_y],
291
  outputs=[preview_with_box]
292
  )
293
 
294
  # ------------------------------------------------------------------
295
+ # CALLBACK #2: on button‐click, run the SR pipeline
296
  # ------------------------------------------------------------------
297
 
298
  run_button.click(
299
  fn=run_with_upload,
300
+ inputs=[upload_image, upscale_radio, center_x, center_y],
301
  outputs=[output_gallery]
302
  )
303
 
 
306
  # START THE GRADIO SERVER
307
  # ------------------------------------------------------------------
308
 
309
+ demo.launch(share=True, debug=True)