import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
# Path to the model in the Hugging Face Space
MODEL_PATH = "pretrained/4xGRL.onnx"  # Adjust this if the model is stored in a different location
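
# Module-level session cache (a minimal sketch): creating an
# ort.InferenceSession is comparatively expensive, so build it once on
# first use and reuse it for every request.
_SESSION = None

def get_session():
    """Create the ONNX Runtime session on first use, then reuse it."""
    global _SESSION
    if _SESSION is None:
        _SESSION = ort.InferenceSession(MODEL_PATH)
    return _SESSION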
# Preprocessing function for images (similar to the original script)
def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
    """Preprocess a PIL image to match the model's input expectations."""
    # PIL images are already RGB (no BGR-to-RGB conversion needed); convert()
    # also drops any alpha channel from e.g. PNG uploads.
    img_rgb = np.array(img.convert("RGB"))
    # Downsample if the short side exceeds the threshold
    h, w, _ = img_rgb.shape
    short_side = min(h, w)
    if short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_rgb = cv2.resize(img_rgb, (int(w / resize_ratio), int(h / resize_ratio)), interpolation=cv2.INTER_LINEAR)
    # Crop height and width down to multiples of 4 so a 4x upscale stays pixel-aligned
    if crop_for_4x:
        h, w, _ = img_rgb.shape
        if h % 4 != 0:
            img_rgb = img_rgb[:4 * (h // 4), :, :]
        if w % 4 != 0:
            img_rgb = img_rgb[:, :4 * (w // 4), :]
    # Resize to the fixed input size the exported model expects (width x height, i.e. 320x180)
    img_resized = cv2.resize(img_rgb, (target_width, target_height))
    return img_resized
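
# Example (sketch): a 1280x720 upload sits at the downsample threshold, is
# already a multiple of 4 on both sides, and ends up at the fixed 320x180 input:
#   preprocess_image(Image.new("RGB", (1280, 720))).shape  # -> (180, 320, 3)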
# Inference function to process an image with the ONNX model
def inference(img, model_name="4xGRL"):
    try:
        # ONNX Runtime takes numpy arrays; the exported model expects float32
        weight_dtype = np.float32
        if model_name == "4xGRL":
            # Reuse the cached ONNX session instead of reloading the model per call
            ort_session = get_session()
            # Preprocess the image (downsample, crop, resize)
            img_resized = preprocess_image(img)
            # Reorder to the (N, C, H, W) layout the model expects
            input_image = np.transpose(img_resized, (2, 0, 1))  # (H, W, C) -> (C, H, W)
            input_image = np.expand_dims(input_image, axis=0)   # add batch dimension
            input_image = input_image.astype(weight_dtype)      # convert to float32
            # Run the model
            ort_inputs = {ort_session.get_inputs()[0].name: input_image}
            ort_outs = ort_session.run(None, ort_inputs)
            # Post-process the output: the upscaled image is the first output
            output_image = ort_outs[0]
            output_image = np.transpose(output_image.squeeze(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
            output_image = np.clip(output_image, 0, 255).astype(np.uint8)   # clamp to a valid 8-bit range
            # Convert the output to a PIL Image for Gradio
            return Image.fromarray(output_image)
        else:
            raise ValueError(f"Model not supported: {model_name}")
    except Exception as error:
        # Raise gr.Error so the message shows in the UI; returning a plain
        # string to an Image output would itself raise an error
        raise gr.Error(f"An error occurred: {error}")
# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Anime Super-Resolution using ONNX")
        gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")
        # File input for the image
        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Image", interactive=True)
        # Process button
        with gr.Row():
            process_button = gr.Button("Process Image")
        # Output for the result image
        with gr.Row():
            result_image = gr.Image(type="pil", label="Processed Image")
        # Wire the button to the inference function
        process_button.click(inference, inputs=input_image, outputs=result_image)
    return demo
# Launch the app (a hosted Space serves the app itself, so no share link is needed)
demo = create_interface()
demo.launch()
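
# Note (sketch): on a busy GPU Space, serializing requests with Gradio's queue,
# i.e. demo.queue().launch(), can keep long-running upscales from timing out.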