File size: 11,693 Bytes
780320d
 
 
 
 
d73c075
780320d
 
d73c075
780320d
d73c075
780320d
87c1890
 
780320d
d73c075
 
 
780320d
d73c075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c1890
24ee135
780320d
24ee135
 
d73c075
24ee135
87c1890
d73c075
780320d
d73c075
 
 
780320d
 
 
 
 
 
 
 
 
 
 
 
 
87c1890
780320d
 
 
 
87c1890
780320d
 
 
 
 
87c1890
780320d
24ee135
780320d
 
 
 
 
 
24ee135
780320d
 
 
 
 
 
 
 
 
 
87c1890
780320d
24ee135
87c1890
 
 
 
24ee135
 
 
87c1890
 
780320d
 
87c1890
 
d73c075
 
87c1890
d73c075
87c1890
 
d73c075
 
 
87c1890
 
 
 
 
 
 
 
 
 
 
780320d
 
d73c075
 
 
780320d
87c1890
780320d
 
24ee135
780320d
 
 
 
 
 
 
87c1890
 
 
 
 
 
 
 
 
 
 
780320d
 
 
24ee135
87c1890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d73c075
 
 
 
 
 
 
 
87c1890
d73c075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c1890
 
 
 
 
 
 
d73c075
 
 
 
87c1890
 
 
 
 
780320d
d73c075
780320d
d73c075
780320d
87c1890
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import gradio as gr
import subprocess
import os
import shutil
from pathlib import Path
from PIL import Image, ImageDraw
import spaces

# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------

# Directory passed to inference_coz.py via `-i`. run_with_upload() empties it
# on every call and writes the uploaded picture into it as input.png.
INPUT_DIR = "samples"

# Directory passed to inference_coz.py via `-o`. Result images are read back
# from <OUTPUT_DIR>/per-sample/input/ and captions from its txt/ sub-folder.
OUTPUT_DIR = "inference_results/coz_vlmprompt"

# ------------------------------------------------------------------
# HELPER: Resize & center-crop to 512, preserving aspect ratio
# ------------------------------------------------------------------

def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Scale `img` so that its shorter side equals `size`, then center-crop it
    to a (size x size) square.

    The resize keeps the aspect ratio; the crop trims the longer side equally
    from both ends.

    Args:
        img: Source PIL image.
        size: Edge length, in pixels, of the square result.

    Returns:
        A new PIL image of (size x size) pixels.

    Raises:
        ZeroDivisionError: If `img` has a zero-length side.
    """
    w, h = img.size
    short_side = min(w, h)

    # Exact integer arithmetic pins the shorter side to `size`. Going through
    # a float factor (size / short_side) can land a hair below the target,
    # e.g. 49 * (512 / 49) == 511.99999999999994, which int() truncates to
    # 511 and leaves the crop box hanging one pixel past the image edge.
    new_w = (w * size) // short_side
    new_h = (h * size) // short_side
    img = img.resize((new_w, new_h), Image.LANCZOS)

    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


# ------------------------------------------------------------------
# HELPER: Draw four concentric, centered rectangles on a 512×512 image
# ------------------------------------------------------------------

def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """
    Build the 512×512 preview with four centered square outlines.

    Steps:
      1) Open the uploaded image from disk and convert it to RGB.
      2) Resize & center-crop it to exactly 512×512.
      3) Depending on scale_option ("1x", "2x", "4x"), compute four box sizes:
         - "1x": [512, 512, 512, 512]
         - "2x": [256, 128, 64, 32]
         - "4x": [128, 64, 32, 16]
      4) Draw each box (outline only), all centered on the preview.

    Args:
        image_path: Path of the image file to preview.
        scale_option: Scale label of the form "<integer>x".

    Returns:
        The annotated 512×512 PIL image. If the file cannot be opened, a gray
        512×512 placeholder with the error message written on it instead.

    Raises:
        ValueError: If `scale_option` is not an integer followed by "x".
    """
    canvas = 512  # edge length of the square preview, in pixels

    try:
        # The context manager releases the file handle as soon as convert()
        # has copied the pixel data into an independent image.
        with Image.open(image_path) as src:
            orig = src.convert("RGB")
    except Exception as e:
        # Any failure to read the upload is shown to the user in the preview
        # itself rather than raised.
        fallback = Image.new("RGB", (canvas, canvas), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, canvas)

    scale_int = int(scale_option.replace("x", ""))  # e.g. "2x" -> 2
    if scale_int == 1:
        sizes = [canvas] * 4
    else:
        # scale=2 -> [256, 128, 64, 32]; scale=4 -> [128, 64, 32, 16]
        sizes = [canvas // (scale_int * (2 ** i)) for i in range(4)]

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]  # one outline color per box
    width = 3  # outline thickness in pixels

    for s, color in zip(sizes, colors):
        if s < 1:
            # Nothing to draw for a degenerate box (very large scale factors).
            continue
        x0 = y0 = (canvas - s) // 2
        # ImageDraw treats both corners as inclusive, so a box that is `s`
        # pixels wide ends at x0 + s - 1. Using x0 + s would draw s + 1 pixels
        # and push the full-size box one pixel past the canvas edge.
        x1 = y1 = x0 + s - 1
        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=width)

    return base


# ------------------------------------------------------------------
# HELPER FUNCTIONS FOR INFERENCE & CAPTION
# ------------------------------------------------------------------
def _clear_directory(directory: str) -> None:
    """Best-effort removal of every file, symlink and sub-directory in `directory`."""
    for entry in os.listdir(directory):
        full_path = os.path.join(directory, entry)
        try:
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            # A leftover entry is not fatal; report it and keep going.
            print(f"Warning: could not delete {full_path}: {e}")


# NOTE(review): presumably the Hugging Face `spaces` decorator that reserves a
# GPU for up to 120 s per call - confirm against the `spaces` package docs.
@spaces.GPU(duration=120)
def run_with_upload(uploaded_image_path, upscale_option):
    """
    Run inference_coz.py on the uploaded image and collect its outputs.

    1) Create INPUT_DIR if needed and empty it.
    2) Save the uploaded file as input.png in INPUT_DIR (converted to RGB).
    3) Turn `upscale_option` (e.g. "1x", "2x", "4x") into "1", "2", "4".
    4) Call inference_coz.py with `--upscale <that_value>`.
    5) Return the four output PNG paths so a Gradio Gallery can show them.

    Args:
        uploaded_image_path: Path of the uploaded image, or None if nothing
            has been uploaded.
        upscale_option: Scale label such as "2x".

    Returns:
        The paths 1.png .. 4.png under OUTPUT_DIR/per-sample/input, or an
        empty list if there is no upload, the image cannot be read or saved,
        the inference process fails, or any expected output is missing.
    """
    import sys  # local import: only needed to locate the running interpreter

    os.makedirs(INPUT_DIR, exist_ok=True)
    _clear_directory(INPUT_DIR)

    if uploaded_image_path is None:
        return []

    try:
        # The context manager closes the upload's file handle once convert()
        # has produced an independent in-memory copy.
        with Image.open(uploaded_image_path) as src:
            pil_img = src.convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return []

    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return []

    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" -> "2"
    cmd = [
        # Use the interpreter running this app so the child process sees the
        # same environment and installed packages; fall back to PATH lookup.
        sys.executable or "python", "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # CalledProcessError: the script exited non-zero.
        # OSError: the process could not be started at all.
        print("Inference failed:", err)
        return []

    per_sample_dir = os.path.join(OUTPUT_DIR, "per-sample", "input")
    expected_files = [
        os.path.join(per_sample_dir, f"{i}.png")
        for i in range(1, 5)
    ]
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return []
    return expected_files


def get_caption(src_gallery, evt: gr.SelectData):
    """
    Return the caption text belonging to the gallery image that was clicked.

    The clicked entry's file name is expected to be "<k>.png"; its caption is
    read from OUTPUT_DIR/per-sample/input/txt/<k-1>.txt.

    Args:
        src_gallery: Current gallery value; each entry is indexed with [0] to
            obtain the image path.
        evt: Selection event; `evt.index` is the position of the clicked entry.

    Returns:
        The stripped caption text, or a human-readable message when the
        gallery is empty, the image file is gone, the file name is not a
        number, the caption file is missing or empty, or reading it fails.
    """
    no_caption = "No caption available."
    if not src_gallery:
        return no_caption

    selected_image_path = src_gallery[evt.index][0]
    if not os.path.isfile(selected_image_path):
        return no_caption

    stem = os.path.splitext(os.path.basename(selected_image_path))[0]  # "2.png" -> "2"
    try:
        # Image k.png is paired with caption (k-1).txt.
        caption_index = int(stem) - 1
    except ValueError:
        # A file name that is not a plain number has no caption to look up;
        # answer gracefully instead of crashing the callback.
        return no_caption

    txt_name = f"{caption_index}.txt"
    txt_path = os.path.join(OUTPUT_DIR, "per-sample", "input", "txt", txt_name)
    if not os.path.isfile(txt_path):
        return f"Caption file not found: {txt_name}"

    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
    except (OSError, UnicodeDecodeError) as e:
        return f"Error reading caption: {e}"
    return caption if caption else "(Caption file is empty.)"


# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (with updated callbacks)
# ------------------------------------------------------------------

# Stylesheet handed to gr.Blocks: centers the element with id "col-container"
# and caps its width. Spelled out line by line so every newline is explicit.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 1024px;\n"
    "}\n"
)

with gr.Blocks(css=css) as demo:

    # Page header: title, tagline and a badge linking to the GitHub repo.
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chain-of-Zoom</h1>
            <p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/bryanswkim/Chain-of-Zoom">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )

    with gr.Column(elem_id="col-container"):

        with gr.Row():
            # Left column: inputs and the preview.
            with gr.Column():
                # 1) Image upload component (callbacks receive a file path)
                upload_image = gr.Image(
                    label="Upload your input image",
                    type="filepath"
                )

                # 2) Radio for choosing 1× / 2× / 4× upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False
                )

                # 3) Button to launch inference
                run_button = gr.Button("Chain-of-Zoom it")

                # 4) Show the 512×512 preview with four centered rectangles
                preview_with_box = gr.Image(
                    label="Preview (512×512 with centered boxes)",
                    type="pil",        # update_preview returns a PIL.Image
                    interactive=False
                )

            # Right column: results and the caption of the selected result.
            with gr.Column():
                # 5) Gallery to display multiple output images
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2], rows=[2]
                )

                # 6) Textbox under the gallery for showing captions
                caption_text = gr.Textbox(
                    label="Caption",
                    lines=4,
                    placeholder="Click on any image above to see its caption here."
                )

        # ------------------------------------------------------------------
        # CALLBACK #1: Whenever the user uploads or changes the radio, update preview
        # ------------------------------------------------------------------

        def update_preview(img_path, scale_opt):
            """
            Return the 512×512 preview with its four boxes for the current
            upload, or None when no image has been uploaded yet.
            """
            if img_path is None:
                return None
            return make_preview_with_boxes(img_path, scale_opt)

        # Redraw the preview both when a new file is uploaded and when the
        # scale choice changes afterwards; the two events share one handler.
        for preview_trigger in (upload_image.change, upscale_radio.change):
            preview_trigger(
                fn=update_preview,
                inputs=[upload_image, upscale_radio],
                outputs=[preview_with_box]
            )

        # ------------------------------------------------------------------
        # CALLBACK #2: When "Chain-of-Zoom it" is clicked, run inference
        # ------------------------------------------------------------------

        run_button.click(
            fn=run_with_upload,
            inputs=[upload_image, upscale_radio],
            outputs=[output_gallery]
        )

        # ------------------------------------------------------------------
        # CALLBACK #3: When an image in the gallery is clicked, show its caption
        # ------------------------------------------------------------------

        output_gallery.select(
            fn=get_caption,
            inputs=[output_gallery],
            outputs=[caption_text]
        )

# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------

# NOTE(review): this runs whenever the module is executed or imported, and
# share=True asks Gradio for a public share link - confirm both are intended
# for the environment the app is hosted in.
demo.launch(share=True)