import os
import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image

# Path to the model in Hugging Face Space
MODEL_PATH = "pretrained/4xGRL.onnx"  # Adjust this if the model is stored in a different location
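
# Fail fast with a clear error if the model file is missing (a small sanity
# check; the path above is assumed to match this Space's repo layout)
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"ONNX model not found at {MODEL_PATH}")

# Build the ONNX Runtime session once at import time; constructing a new
# session on every request is expensive
ort_session = ort.InferenceSession(MODEL_PATH)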

# Preprocessing function for images (similar to original script)
def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
    ''' Preprocess the image to match the model's input expectations '''
    # Gradio supplies a PIL image, which is already RGB; a cv2.COLOR_BGR2RGB
    # conversion here would wrongly swap the red and blue channels. The
    # convert("RGB") call also normalizes grayscale/RGBA inputs to 3 channels.
    img_rgb = np.array(img.convert("RGB"))
    
    # Measure the current dimensions and short side
    h, w, _ = img_rgb.shape
    short_side = min(h, w)

    # Downsample if the short side exceeds the threshold
    if short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_rgb = cv2.resize(img_rgb, (int(w / resize_ratio), int(h / resize_ratio)), interpolation=cv2.INTER_LINEAR)
    
    # Crop to match 4x scaling if needed
    if crop_for_4x:
        h, w, _ = img_rgb.shape
        if h % 4 != 0:
            img_rgb = img_rgb[:4 * (h // 4), :, :]
        if w % 4 != 0:
            img_rgb = img_rgb[:, :4 * (w // 4), :]
    
    # Resize to the model's fixed input size; note cv2.resize takes (width, height)
    img_resized = cv2.resize(img_rgb, (target_width, target_height))
    
    return img_resized
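
# Example (hypothetical sizes): a 1920x1080 input is downsampled so its short
# side is at most 720, cropped to multiples of 4, then resized to 320x180:
#   dummy = Image.fromarray(np.zeros((1080, 1920, 3), dtype=np.uint8))
#   assert preprocess_image(dummy).shape == (180, 320, 3)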

# Inference function to process image using ONNX model
def inference(img, model_name="4xGRL"):
    try:
        # Ensure correct dtype for ONNX
        weight_dtype = np.float32  # ONNX uses numpy arrays, so use np.float32
        
        if model_name == "4xGRL":
            # Reuse the ONNX Runtime session created at module load time

            # Preprocess the image (resize, crop, etc.)
            img_resized = preprocess_image(img)

            # Prepare the input in the format expected by the model (e.g., (N, C, H, W))
            input_image = np.transpose(img_resized, (2, 0, 1))  # Convert to (C, H, W)
            input_image = np.expand_dims(input_image, axis=0)  # Add batch dimension
            input_image = input_image.astype(weight_dtype)  # Convert to float32

            # Run the model
            ort_inputs = {ort_session.get_inputs()[0].name: input_image}
            ort_outs = ort_session.run(None, ort_inputs)

            # Post-process the output
            output_image = ort_outs[0]  # Assuming the model output is in the first position
            output_image = np.transpose(output_image.squeeze(), (1, 2, 0))  # Convert to (H, W, C)
            output_image = np.clip(output_image, 0, 255).astype(np.uint8)  # Ensure valid image range

            # Convert output to PIL Image for Gradio
            output_pil = Image.fromarray(output_image)

            return output_pil

        else:
            raise ValueError(f"Unsupported model: {model_name}")

    except Exception as error:
        # Raise a Gradio error so the failure is shown in the UI; returning a
        # plain string to an Image output component would itself fail
        raise gr.Error(f"An error occurred: {error}")
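
# Sketch (assumes the ONNX graph declares a static NCHW input shape): rather
# than hard-coding 180x320, the expected size could be read from the session:
#   _, _, model_h, model_w = ort_session.get_inputs()[0].shape
#   img_resized = preprocess_image(img, target_height=model_h, target_width=model_w)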

# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Anime Super-Resolution using ONNX")
        gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")

        # File input for image
        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Image", interactive=True)

        # Process button
        with gr.Row():
            process_button = gr.Button("Process Image")

        # Output for result image
        with gr.Row():
            result_image = gr.Image(type="pil", label="Processed Image")

        # Wire the button to the inference function
        process_button.click(inference, inputs=input_image, outputs=result_image)

    return demo

# Launch the app (share=True is unnecessary on Hugging Face Spaces, where the
# app is already served publicly; it only matters for local runs)
demo = create_interface()
demo.launch()
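
# Optional sketch (an assumption about deployment, not part of the original
# flow): for slow models, calling demo.queue() before demo.launch() makes
# concurrent requests wait in a queue instead of timing out.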