Chain-of-Zoom / app.py
alexnasa's picture
Update app.py
f246a5c verified
raw
history blame
5.78 kB
import os
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
from PIL import Image
import spaces
# -----------------------------------------------------------------------------
# PROJECT PATHS — adjust these to match your checkout.
# -----------------------------------------------------------------------------
# Directory inference_coz.py reads its input images from (passed as `-i`).
INPUT_DIR: str = "samples"
# Directory inference_coz.py writes its results into (passed as `-o`).
OUTPUT_DIR: str = "inference_results/coz_vlmprompt"
# -----------------------------------------------------------------------------
# HELPER FUNCTIONS TO RUN INFERENCE AND RETURN THE OUTPUT IMAGE
# -----------------------------------------------------------------------------
def _clear_directory(directory):
    """Create *directory* if missing, then delete everything inside it.

    Best effort: an entry that cannot be removed is reported and skipped so a
    single locked file does not abort the whole request.
    """
    os.makedirs(directory, exist_ok=True)
    for entry in Path(directory).iterdir():
        try:
            # Symlinks (even to directories) are unlinked, never followed.
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                entry.unlink()
        except OSError as e:
            print(f"Warning: could not delete {entry}: {e}")


def _png_mtimes(directory):
    """Map each ``.png`` file (case-insensitive) in *directory* to its mtime in ns.

    Returns an empty dict when *directory* does not exist.
    """
    directory = Path(directory)
    if not directory.is_dir():
        return {}
    return {
        path: path.stat().st_mtime_ns
        for path in directory.iterdir()
        if path.name.lower().endswith(".png") and path.is_file()
    }


@spaces.GPU()
def run_with_upload(uploaded_image_path):
    """Run Chain-of-Zoom inference on one uploaded image and return the result.

    Steps:
      1) Return ``None`` right away if nothing was uploaded.
      2) Clear INPUT_DIR (so old samples don't linger) and save the upload
         there as ``input.png``.
      3) Run ``inference_coz.py`` (reads ``-i INPUT_DIR``, writes
         ``-o OUTPUT_DIR``) with the same interpreter that runs this app.
      4) Return the most recently modified PNG in ``OUTPUT_DIR/recursive``
         among those created or rewritten by *this* run, so a stale result
         from an earlier request is never shown.

    NOTE(review): INPUT_DIR/OUTPUT_DIR are shared by every request; concurrent
    runs would race on them — confirm the Gradio queue runs one at a time.

    Args:
        uploaded_image_path: Local file path from the
            ``gr.Image(type="filepath")`` component, or ``None`` when the
            button is clicked without an upload.

    Returns:
        The output as an RGB ``PIL.Image.Image``, or ``None`` on any failure
        (the reason is printed to stdout).
    """
    # 1) Nothing to do; check before touching INPUT_DIR.
    if uploaded_image_path is None:
        return None

    # 2) Start from an empty input directory so only this upload is processed.
    _clear_directory(INPUT_DIR)
    try:
        # PIL decodes JPEG, BMP, TIFF, etc.; the context manager closes the
        # source file (convert() returns an independent image).
        with Image.open(uploaded_image_path) as src:
            pil_img = src.convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return None
    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return None

    # 3) Snapshot the existing outputs before running, so results left over
    #    from earlier runs can be told apart from this run's output.
    recursive_dir = Path(OUTPUT_DIR) / "recursive"
    before = _png_mtimes(recursive_dir)
    cmd = [
        # Use this app's interpreter so the child sees the same environment.
        sys.executable or "python", "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", "2",
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)  # blocks until inference finishes
    except (subprocess.CalledProcessError, OSError) as err:
        # Non-zero exit, or the interpreter could not be launched at all.
        print("Inference failed:", err)
        return None

    # 4) Keep only PNGs that are new or whose mtime changed during this run.
    after = _png_mtimes(recursive_dir)
    fresh = [path for path, mtime in after.items() if before.get(path) != mtime]
    if not fresh:
        return None
    latest_png = max(fresh, key=after.get)

    # 5) Load the result fully into memory, then release the file.
    try:
        with Image.open(latest_png) as result:
            img = result.convert("RGB")
    except Exception as e:
        print(f"Error opening {latest_png}: {e}")
        return None
    return img
# -----------------------------------------------------------------------------
# GRADIO USER INTERFACE
# -----------------------------------------------------------------------------
# Centre the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 720px;
}
"""

# Title banner with a link to the upstream repository.
_HEADER_HTML = """
<div style="text-align: center;">
<h1>Chain-of-Zoom</h1>
<p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment </p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/bryanswkim/Chain-of-Zoom">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
</div>
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Column(elem_id="col-container"):
        # type="filepath": the handler receives a local path string, not pixels.
        upload_image = gr.Image(label="Upload your input image", type="filepath")

        run_button = gr.Button("Run Inference")

        # type="pil": run_with_upload hands back a PIL.Image (or None on failure).
        output_image = gr.Image(label="Inference Result", type="pil")

        # Clicking the button feeds the uploaded path to run_with_upload and
        # shows whatever it returns in the output image.
        run_button.click(
            fn=run_with_upload,
            inputs=upload_image,
            outputs=output_image,
        )

# -----------------------------------------------------------------------------
# LAUNCH (share=True asks Gradio for a public share link)
# -----------------------------------------------------------------------------
demo.launch(share=True)