"""Gradio front end that runs ``inference_coz.py`` on a single uploaded image."""

import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Optional

import gradio as gr
import spaces
from PIL import Image

# -----------------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# -----------------------------------------------------------------------------
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"


# -----------------------------------------------------------------------------
# HELPER FUNCTIONS TO RUN INFERENCE AND RETURN THE OUTPUT IMAGE
# -----------------------------------------------------------------------------
def _clear_directory(directory: str) -> None:
    """Create *directory* if needed, then delete everything inside it.

    Deletion is best-effort: an entry that cannot be removed is reported on
    stdout and skipped so that one stubborn file does not abort the run.
    """
    os.makedirs(directory, exist_ok=True)
    for entry in os.listdir(directory):
        full_path = os.path.join(directory, entry)
        try:
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            print(f"Warning: could not delete {full_path}: {e}")


def _build_command() -> list:
    """Return the argv list that launches ``inference_coz.py``."""
    # Use the interpreter that is running this app so the child process sees
    # the same environment; fall back to "python" if it cannot be determined.
    interpreter = sys.executable or "python"
    return [
        interpreter, "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", "2",
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path",
        "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]


def _latest_png(directory: str) -> Optional[str]:
    """Return the most recently modified ``.png`` in *directory*, or None.

    None is returned when *directory* does not exist or holds no PNG files.
    """
    if not os.path.isdir(directory):
        return None
    png_files = [
        os.path.join(directory, fn)
        for fn in os.listdir(directory)
        if fn.lower().endswith(".png")
    ]
    if not png_files:
        return None
    return max(png_files, key=os.path.getmtime)


@spaces.GPU()
def run_with_upload(uploaded_image_path: Optional[str]) -> Optional[Image.Image]:
    """Run inference on an uploaded image and return the resulting image.

    Steps:
      1) Clear out INPUT_DIR (so old samples don't linger).
      2) Re-encode the uploaded image as ``input.png`` inside INPUT_DIR.
      3) Run ``inference_coz.py`` (which reads from ``-i INPUT_DIR``) and
         block until it completes.
      4) Find the most recently modified PNG in ``OUTPUT_DIR/recursive``.
      5) Return it as a PIL.Image, which Gradio will display.

    Args:
        uploaded_image_path: Local path of the uploaded file as supplied by
            the Gradio image component, or None when nothing was uploaded.

    Returns:
        The result as an RGB ``PIL.Image``, or None if there was no upload
        or any step failed (the reason is printed to stdout).
    """
    # Nothing to do without an upload; bail out before touching INPUT_DIR.
    if uploaded_image_path is None:
        return None

    # 1) Make sure INPUT_DIR exists and is empty.
    _clear_directory(INPUT_DIR)

    # 2) Re-encode the upload as PNG inside INPUT_DIR.
    try:
        # Open with PIL (this handles JPEG, BMP, TIFF, etc.). The context
        # manager closes the source file; convert() returns a detached copy.
        with Image.open(uploaded_image_path) as source:
            pil_img = source.convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return None

    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return None

    # 3) Run inference_coz.py. This blocks until it completes.
    try:
        subprocess.run(_build_command(), check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # Covers both a non-zero exit status and a failure to launch at all.
        print("Inference failed:", err)
        return None

    # 4) Pick the most recently modified PNG from the recursive output folder.
    # NOTE(review): nothing clears this folder between runs, so if a run
    # writes no new PNG a file from an earlier run may be returned.
    recursive_dir = f"{OUTPUT_DIR}/recursive"
    latest_png = _latest_png(recursive_dir)
    if latest_png is None:
        return None

    # 5) Load it fully into memory and return; Gradio displays it.
    try:
        with Image.open(latest_png) as result:
            img = result.convert("RGB")
    except Exception as e:
        print(f"Error opening {latest_png}: {e}")
        return None
    return img


# -----------------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Upload an image, then click **Run Inference** to process it.")

    # 1) Image upload component. type="filepath" means the callback
    #    (run_with_upload) receives a local path to the uploaded file.
    upload_image = gr.Image(
        label="Upload your input image",
        type="filepath",
    )

    # 2) A button that the user clicks to launch inference.
    run_button = gr.Button("Run Inference")

    # 3) An output where the final PNG is shown.
    output_image = gr.Image(
        label="Inference Result",
        type="pil",  # because run_with_upload() returns a PIL.Image
    )

    # Wire the button: when clicked, call run_with_upload(upload_image) and
    # put its return value into output_image.
    run_button.click(
        fn=run_with_upload,
        inputs=upload_image,
        outputs=output_image,
    )

# -----------------------------------------------------------------------------
# START THE GRADIO SERVER
# -----------------------------------------------------------------------------
demo.launch(share=True)