Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import subprocess | |
import os | |
import shutil | |
from pathlib import Path | |
from PIL import Image | |
import spaces | |
# -----------------------------------------------------------------------------
# PROJECT PATHS — edit these to match your checkout
# -----------------------------------------------------------------------------
# INPUT_DIR  : passed to inference_coz.py as -i; it is emptied and the uploaded
#              image is written into it as input.png before each inference.
# OUTPUT_DIR : passed to inference_coz.py as -o; the app reads its result from
#              the "recursive" sub-folder of this directory.
INPUT_DIR, OUTPUT_DIR = "samples", "inference_results/coz_vlmprompt"
# ----------------------------------------------------------------------------- | |
# HELPER FUNCTION TO RUN INFERENCE AND RETURN THE OUTPUT IMAGE | |
# ----------------------------------------------------------------------------- | |
def run_with_upload(uploaded_image_path):
    """Run CoZ inference on one uploaded image and return the result image.

    Steps:
      1) Decode the upload with PIL first, so that a missing or unreadable
         upload leaves INPUT_DIR untouched.
      2) Empty INPUT_DIR (so old samples don't linger) and save the upload
         there as ``input.png``.
      3) Run ``inference_coz.py`` (reads ``-i INPUT_DIR``, writes
         ``-o OUTPUT_DIR``) and block until it exits.
      4) Among the PNGs in ``OUTPUT_DIR/recursive`` that this run created or
         rewrote, pick the most recently modified one.
      5) Return it as an RGB ``PIL.Image`` for Gradio to display.

    Args:
        uploaded_image_path: Local file path from ``gr.Image(type="filepath")``,
            or ``None`` when nothing has been uploaded.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` if any step fails. Failures are
        printed rather than raised, so the UI simply shows no result.
    """
    # NOTE(review): `spaces` is imported at file level but this handler is not
    # decorated with `@spaces.GPU`; confirm how the inference subprocess gets
    # GPU access on ZeroGPU hardware.
    # NOTE(review): INPUT_DIR / OUTPUT_DIR are shared by all requests; this
    # assumes requests are never processed concurrently — confirm.
    if uploaded_image_path is None:
        return None

    # Decode before clearing INPUT_DIR so a bad upload cannot wipe it.
    try:
        # The `with` closes the source file handle; convert() returns a
        # separate, fully-loaded image.
        with Image.open(uploaded_image_path) as src:
            pil_img = src.convert("RGB")  # also normalizes JPEG/BMP/TIFF/RGBA
    except Exception as e:  # PIL signals bad input with several exception types
        print(f"Error: could not open uploaded image: {e}")
        return None

    _clear_directory(INPUT_DIR)
    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return None

    recursive_dir = os.path.join(OUTPUT_DIR, "recursive")
    # Snapshot existing outputs: without this, a run that writes nothing would
    # return a stale PNG left over from an earlier request.
    previous_mtimes = _png_mtimes(recursive_dir)

    if not _run_inference():
        return None

    fresh = {
        path: mtime
        for path, mtime in _png_mtimes(recursive_dir).items()
        if previous_mtimes.get(path) != mtime  # new file or rewritten by this run
    }
    if not fresh:
        return None
    latest_png = max(fresh, key=fresh.get)

    try:
        with Image.open(latest_png) as result:
            return result.convert("RGB")
    except Exception as e:
        print(f"Error opening {latest_png}: {e}")
        return None


def _clear_directory(directory):
    """Create *directory* if needed and delete its contents, best effort."""
    os.makedirs(directory, exist_ok=True)
    for fn in os.listdir(directory):
        full_path = os.path.join(directory, fn)
        try:
            # Links are removed themselves, never followed into their targets.
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            # Deliberately non-fatal: a leftover entry should not block a run.
            print(f"Warning: could not delete {full_path}: {e}")


def _png_mtimes(directory):
    """Map each ``.png`` entry directly in *directory* to its mtime.

    Returns an empty dict when *directory* does not exist.
    """
    if not os.path.isdir(directory):
        return {}
    mtimes = {}
    for fn in os.listdir(directory):
        if not fn.lower().endswith(".png"):
            continue
        path = os.path.join(directory, fn)
        try:
            mtimes[path] = os.path.getmtime(path)
        except OSError:
            # The entry disappeared between listdir() and stat(); skip it.
            continue
    return mtimes


def _run_inference():
    """Run inference_coz.py from INPUT_DIR to OUTPUT_DIR.

    Returns:
        True if the script exits with status 0; False (after printing the
        error) if it exits with a non-zero status.
    """
    import sys  # local import: the file's top-level imports do not include sys

    cmd = [
        # sys.executable runs the same interpreter/venv as this app; a bare
        # "python" would be whatever comes first on PATH.
        sys.executable, "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", "2",
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)  # blocks until the script exits
    except subprocess.CalledProcessError as err:
        print("Inference failed:", err)
        return False
    return True
# ----------------------------------------------------------------------------- | |
# BUILD THE GRADIO INTERFACE | |
# ----------------------------------------------------------------------------- | |
with gr.Blocks() as demo:
    # Heading that explains how to use the page.
    gr.Markdown("## Upload an image, then click **Run Inference** to process it.")

    # Input widget. type="filepath" means run_with_upload receives a local path
    # to the uploaded file (run_with_upload also handles a None value).
    upload_image = gr.Image(label="Upload your input image", type="filepath")

    # Button that starts an inference run.
    run_button = gr.Button("Run Inference")

    # Result widget. run_with_upload returns a PIL image (or None), hence "pil".
    output_image = gr.Image(label="Inference Result", type="pil")

    # Event wiring: on click, call run_with_upload(<upload path>) and show its
    # return value in output_image. Arguments are (fn, inputs, outputs).
    run_button.click(run_with_upload, upload_image, output_image)

# -----------------------------------------------------------------------------
# START THE GRADIO SERVER
# -----------------------------------------------------------------------------
demo.launch(share=True)