import os
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
from PIL import Image, ImageDraw
import spaces

# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"

# Where inference_coz.py writes the per-sample results for "input.png":
# result images 1.png..4.png, and caption files txt/0.txt..txt/3.txt.
PER_SAMPLE_DIR = os.path.join(OUTPUT_DIR, "per-sample", "input")
CAPTION_DIR = os.path.join(PER_SAMPLE_DIR, "txt")

# Side length (pixels) of the square preview shown in the UI.
PREVIEW_SIZE = 512


# ------------------------------------------------------------------
# HELPER: Resize & center-crop to a square, preserving aspect ratio
# ------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Resize `img` so that its shorter side equals `size` (aspect ratio kept),
    then center-crop it to exactly (size x size).

    The shorter side is set to `size` directly. Computing it as
    ``int(side * size / short_side)`` can truncate to ``size - 1`` because of
    float rounding (e.g. a 49 px side yields 511). The crop box would then run
    past the image edge, and PIL pads that area with black.
    """
    w, h = img.size
    if w <= h:
        new_w, new_h = size, max(size, round(h * size / w))
    else:
        new_w, new_h = max(size, round(w * size / h)), size
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


# ------------------------------------------------------------------
# HELPER: Draw four concentric, centered rectangles on the preview
# ------------------------------------------------------------------
def _box_sizes(scale_int: int, size: int = PREVIEW_SIZE) -> list:
    """Return the side lengths of the four nested preview boxes for a factor."""
    if scale_int <= 1:
        return [size] * 4
    # With size=512: scale=2 -> [256, 128, 64, 32]; scale=4 -> [128, 64, 32, 16].
    # Clamp at 1 px so very large factors never produce an empty/inverted box
    # (Pillow rejects rectangles whose x1 < x0).
    return [max(1, size // (scale_int * (2 ** i))) for i in range(4)]


def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """
    Build the UI preview: the uploaded image, center-cropped to
    PREVIEW_SIZE x PREVIEW_SIZE, with four centered rectangle outlines.

    Box sizes depend on `scale_option` ("1x", "2x", "4x", ...):
      - "1x": [512, 512, 512, 512]
      - "2x": [256, 128, 64, 32]
      - "4x": [128, 64, 32, 16]

    If the image cannot be opened, a gray placeholder with the error text is
    returned instead of raising, so the preview component always updates.
    """
    try:
        with Image.open(image_path) as im:
            orig = im.convert("RGB")
    except Exception as e:  # UI boundary: show the problem instead of crashing
        fallback = Image.new("RGB", (PREVIEW_SIZE, PREVIEW_SIZE), (200, 200, 200))
        draw = ImageDraw.Draw(fallback)
        draw.text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, PREVIEW_SIZE)

    scale_int = int(scale_option.replace("x", ""))  # e.g. "2x" -> 2
    sizes = _box_sizes(scale_int, PREVIEW_SIZE)

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]  # one outline color per box
    width = 3  # outline thickness in pixels
    for idx, s in enumerate(sizes):
        x0 = y0 = (PREVIEW_SIZE - s) // 2
        # Pillow's rectangle bounding box includes both endpoints, so an
        # s-pixel box ends at x0 + s - 1. Using x0 + s would put the
        # right/bottom edge one pixel off-image for a full-size box.
        x1 = y1 = x0 + s - 1
        draw.rectangle(
            [(x0, y0), (x1, y1)],
            outline=colors[idx % len(colors)],
            width=width,
        )
    return base


# ------------------------------------------------------------------
# HELPER FUNCTIONS FOR INFERENCE & CAPTION
# ------------------------------------------------------------------
def _clear_directory(path: str) -> None:
    """Create `path` if needed and delete everything inside it (best effort)."""
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        try:
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            print(f"Warning: could not delete {full_path}: {e}")


def _remove_stale_outputs() -> None:
    """Delete result images/captions from a previous run so they are never re-served."""
    stale = [os.path.join(PER_SAMPLE_DIR, f"{i}.png") for i in range(1, 5)]
    stale += [os.path.join(CAPTION_DIR, f"{i}.txt") for i in range(4)]
    for fp in stale:
        try:
            os.remove(fp)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Warning: could not delete {fp}: {e}")


@spaces.GPU(duration=120)
def run_with_upload(uploaded_image_path, upscale_option):
    """
    Run Chain-of-Zoom inference on the uploaded image.

    1) Clear INPUT_DIR.
    2) Save the uploaded file as input.png in INPUT_DIR.
    3) Turn `upscale_option` ("1x", "2x", "4x") into "1", "2", "4".
    4) Call inference_coz.py with `--upscale <N>` using the current
       Python interpreter.
    5) Return the four output PNG paths as a list for the Gallery.

    Returns [] if any step fails; the reason is printed to the console.
    """
    _clear_directory(INPUT_DIR)

    if uploaded_image_path is None:
        return []

    try:
        with Image.open(uploaded_image_path) as im:
            pil_img = im.convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return []

    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return []

    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" -> "2"

    # Without this, a run that exits 0 but writes fewer files would
    # silently show results from an earlier upload.
    _remove_stale_outputs()

    cmd = [
        # Use the running interpreter (same venv) rather than whatever
        # "python" resolves to on PATH.
        sys.executable, "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        print("Inference failed:", err)
        return []

    expected_files = [os.path.join(PER_SAMPLE_DIR, f"{i}.png") for i in range(1, 5)]
    missing = [fp for fp in expected_files if not os.path.isfile(fp)]
    if missing:
        for fp in missing:
            print(f"Warning: expected file not found: {fp}")
        return []

    return expected_files


def get_caption(src_gallery, evt: gr.SelectData):
    """
    Return the caption for the gallery image the user clicked.

    Output image ``<k>.png`` is paired with caption file ``<k-1>.txt`` in
    CAPTION_DIR. Lookup problems are returned as a message string rather
    than raised, so the caption Textbox always shows something.
    """
    idx = evt.index
    if not src_gallery or not isinstance(idx, int) or not 0 <= idx < len(src_gallery):
        return "No caption available."

    item = src_gallery[idx]
    # Entries are indexed with [0]; presumably (path, caption) pairs as sent
    # by gr.Gallery. Also accept a bare path string.
    selected_image_path = item[0] if isinstance(item, (list, tuple)) else item
    if not isinstance(selected_image_path, str) or not os.path.isfile(selected_image_path):
        return "No caption available."

    stem = Path(selected_image_path).stem  # e.g. "2" for ".../2.png"
    try:
        caption_name = f"{int(stem) - 1}.txt"
    except ValueError:  # file name is not a result index
        return "No caption available."

    txt_path = os.path.join(CAPTION_DIR, caption_name)
    if not os.path.isfile(txt_path):
        return f"Caption file not found: {caption_name}"

    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
    except Exception as e:
        return f"Error reading caption: {e}"
    return caption if caption else "(Caption file is empty.)"


# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE
# ------------------------------------------------------------------
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """

Chain-of-Zoom

Extreme Super-Resolution via Scale Autoregression and Preference Alignment


"""
    )

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # 1) Image upload component
                upload_image = gr.Image(
                    label="Upload your input image",
                    type="filepath",
                )

                # 2) Radio for choosing 1x / 2x / 4x upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False,
                )

                # 3) Button to launch inference
                run_button = gr.Button("Chain-of-Zoom it")

                # 4) 512x512 preview with four centered rectangles
                preview_with_box = gr.Image(
                    label="Preview (512×512 with centered boxes)",
                    type="pil",  # update_preview returns a PIL.Image
                    interactive=False,
                )

            with gr.Column():
                # 5) Gallery to display the output images
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                )

                # 6) Textbox under the gallery for showing captions
                caption_text = gr.Textbox(
                    label="Caption",
                    lines=4,
                    placeholder="Click on any image above to see its caption here.",
                )

    # ------------------------------------------------------------------
    # CALLBACK #1: Redraw the preview when the image or the scale changes
    # ------------------------------------------------------------------
    def update_preview(img_path, scale_opt):
        """
        Return None (blank preview) when no image is uploaded; otherwise
        return the cropped preview with four boxes drawn on it.
        """
        if img_path is None:
            return None
        return make_preview_with_boxes(img_path, scale_opt)

    # When the user uploads a new file:
    upload_image.change(
        fn=update_preview,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box],
    )

    # When the user switches 1x/2x/4x after uploading:
    upscale_radio.change(
        fn=update_preview,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box],
    )

    # ------------------------------------------------------------------
    # CALLBACK #2: "Chain-of-Zoom it" runs inference
    # ------------------------------------------------------------------
    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio],
        outputs=[output_gallery],
    )

    # ------------------------------------------------------------------
    # CALLBACK #3: Clicking a gallery image shows its caption
    # ------------------------------------------------------------------
    output_gallery.select(
        fn=get_caption,
        inputs=[output_gallery],
        outputs=[caption_text],
    )

# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.launch(share=True)