"""Gradio app that upscales anime images 4x with an ONNX export of the GRL model."""

import functools
import logging
import os

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image

logger = logging.getLogger(__name__)

# Path to the model in the Hugging Face Space. Set the GRL_MODEL_PATH
# environment variable if the model is stored in a different location.
MODEL_PATH = os.environ.get("GRL_MODEL_PATH", "pretrained/4xGRL.onnx")

# NOTE(review): assumes the exported graph takes float32 RGB in [0, 1] and
# returns values in [0, 1], like the upstream APISR/GRL PyTorch pipeline
# (torchvision ToTensor / save_image). Confirm against the export script and
# set to False if the ONNX graph works on raw 0-255 values instead.
NORMALIZE_TO_UNIT_RANGE = True


def _to_rgb_array(img):
    """Return *img* (PIL image or array) as an H x W x 3 RGB numpy array.

    Gradio's ``type="pil"`` input is already RGB, so no BGR->RGB swap is
    applied. Palette, grayscale and RGBA uploads are reduced/expanded to three
    channels so later shape handling and the 3-channel model input both work.
    """
    if isinstance(img, Image.Image):
        return np.asarray(img.convert("RGB"))
    arr = np.asarray(img)
    if arr.ndim == 2 or arr.shape[2] == 1:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    if arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
    return arr


def preprocess_image(img, target_height=180, target_width=320,
                     crop_for_4x=True, downsample_threshold=720):
    """Turn an uploaded image into the RGB array that is fed to the model.

    Args:
        img: RGB image, either a PIL image or an H x W (x C) array.
        target_height: Height of the returned array.
        target_width: Width of the returned array.
        crop_for_4x: Trim trailing rows/columns so both sides are multiples
            of 4 before the final resize.
        downsample_threshold: If the shorter side exceeds this many pixels,
            the image is first scaled down (aspect ratio preserved) so that
            side becomes roughly this length.

    Returns:
        A ``(target_height, target_width, 3)`` RGB array.
    """
    img_rgb = _to_rgb_array(img)

    h, w = img_rgb.shape[:2]
    short_side = min(h, w)
    if short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_rgb = cv2.resize(
            img_rgb,
            (int(w / resize_ratio), int(h / resize_ratio)),
            interpolation=cv2.INTER_LINEAR,
        )

    if crop_for_4x:
        h, w = img_rgb.shape[:2]
        # Skip for images smaller than 4 px so the crop never yields an
        # empty array (cv2.resize would reject it).
        if h >= 4 and w >= 4:
            img_rgb = img_rgb[: h - h % 4, : w - w % 4]

    # NOTE(review): this resize ignores aspect ratio, presumably because the
    # ONNX graph was exported with a fixed 180x320 input (confirm). Images
    # that are not 16:9 come out stretched.
    return cv2.resize(img_rgb, (target_width, target_height))


@functools.lru_cache(maxsize=None)
def _load_session(model_path):
    """Create the ONNX Runtime session for *model_path* once and reuse it."""
    # Building a session loads and optimizes the whole graph; caching avoids
    # paying that cost on every button click.
    return ort.InferenceSession(model_path)


def _to_model_input(img_rgb):
    """Convert an H x W x 3 RGB array into a C-contiguous (1, 3, H, W) float32 batch."""
    # astype copies by default, so the in-place division cannot touch img_rgb.
    batch = np.transpose(img_rgb, (2, 0, 1))[np.newaxis].astype(np.float32, order="C")
    if NORMALIZE_TO_UNIT_RANGE:
        batch /= 255.0
    return batch


def _to_output_image(raw):
    """Convert the model output into an H x W x 3 uint8 RGB array.

    Assumes the output is (1, C, H, W) or (C, H, W) -- TODO confirm against
    the exported graph. Only the batch axis is dropped (``squeeze`` would
    also collapse any other size-1 axis).
    """
    chw = raw[0] if raw.ndim == 4 else raw
    hwc = np.transpose(chw, (1, 2, 0))
    if NORMALIZE_TO_UNIT_RANGE:
        hwc = hwc * 255.0
    # Round before clipping/casting; a bare astype(uint8) truncates.
    return np.clip(np.rint(hwc), 0, 255).astype(np.uint8)


def inference(img, model_name="4xGRL"):
    """Upscale *img* with the selected ONNX model and return a PIL image.

    Args:
        img: Uploaded image from Gradio (``type="pil"``), or ``None`` when the
            button is pressed before anything is uploaded.
        model_name: Model identifier; only ``"4xGRL"`` is supported.

    Returns:
        The upscaled image as a ``PIL.Image.Image``.

    Raises:
        gr.Error: On missing input, an unsupported model, or any failure while
            loading/running the model. Raising (instead of returning the text)
            lets Gradio show the message to the user rather than passing a
            string to the output image component.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if model_name != "4xGRL":
        raise gr.Error("An error occurred: Model not supported")

    try:
        ort_session = _load_session(MODEL_PATH)
        model_input = _to_model_input(preprocess_image(img))
        ort_inputs = {ort_session.get_inputs()[0].name: model_input}
        # The first output is taken as the super-resolved image.
        raw_output = ort_session.run(None, ort_inputs)[0]
        output_image = _to_output_image(raw_output)
    except Exception as error:  # UI boundary: log the traceback, then surface it
        logger.exception("4xGRL inference failed")
        raise gr.Error(f"An error occurred: {error}") from error

    return Image.fromarray(output_image)


def create_interface():
    """Build the Gradio Blocks UI: one image upload, a button, one result image."""
    with gr.Blocks() as demo:
        gr.Markdown("# Anime Super-Resolution using ONNX")
        gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")

        # File input for image
        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Image", interactive=True)

        # Process button
        with gr.Row():
            process_button = gr.Button("Process Image")

        # Output for result image
        with gr.Row():
            result_image = gr.Image(type="pil", label="Processed Image")

        process_button.click(inference, inputs=input_image, outputs=result_image)

    return demo


# Built at import time so tooling that imports the module can find `demo`;
# the server itself only starts when the file is run as a script.
demo = create_interface()

if __name__ == "__main__":
    demo.launch(share=True)