File size: 4,321 Bytes
20cc436
 
634fff1
20cc436
 
5a91dbf
8de6861
5a91dbf
20cc436
 
 
f899183
 
 
 
 
 
 
 
5a91dbf
f899183
20cc436
f899183
 
 
634fff1
 
f899183
5a91dbf
 
f899183
5a91dbf
 
 
 
 
 
 
 
 
 
 
 
f899183
 
 
634fff1
5a91dbf
 
 
 
 
 
 
634fff1
5a91dbf
 
 
 
 
 
 
 
 
634fff1
5a91dbf
634fff1
845613a
5a91dbf
 
634fff1
 
5a91dbf
 
 
 
 
 
 
 
 
 
 
 
 
 
634fff1
5a91dbf
 
 
f899183
5a91dbf
 
634fff1
5a91dbf
20cc436
 
 
634fff1
5a91dbf
634fff1
 
5a91dbf
 
 
 
 
634fff1
 
5a91dbf
 
634fff1
5a91dbf
 
20cc436
 
5a91dbf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import io
import os
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from RealESRGAN import RealESRGAN

# Inference device shared by every model below: CUDA when torch reports it is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_model(scale):
    """Build a RealESRGAN upscaler for ``scale`` on ``device`` and load its weights.

    The weights file is ``weights/RealESRGAN_x{scale}.pth``. Loading is first
    attempted with ``download=True``. If that raises, the error is printed and
    loading is retried once with ``download=False``. Any error from the retry
    propagates to the caller.

    Args:
        scale: Integer upscaling factor (this file uses 2, 4 and 8).

    Returns:
        The RealESRGAN instance with weights loaded.
    """
    weights_file = f'weights/RealESRGAN_x{scale}.pth'
    upscaler = RealESRGAN(device, scale=scale)
    try:
        upscaler.load_weights(weights_file, download=True)
        print(f"Weights for scale {scale} loaded successfully.")
    except Exception as load_error:
        # Best-effort fallback: try whatever is already on disk without downloading.
        print(f"Error loading weights for scale {scale}: {load_error}")
        upscaler.load_weights(weights_file, download=False)
    return upscaler

# Every supported upscaler is built once at import time (in the order 2x, 4x,
# 8x) so individual requests do not pay the model-loading cost.
model2, model4, model8 = (load_model(factor) for factor in (2, 4, 8))

def enhance_image(image, scale):
    """Upscale ``image`` with the preloaded RealESRGAN model matching ``scale``.

    Args:
        image: A PIL image (anything with ``.convert('RGB')``). It is converted
            to an RGB numpy array before inference.
        scale: One of ``'2x'``, ``'4x'`` or ``'8x'``, selecting ``model2``,
            ``model4`` or ``model8`` respectively.

    Returns:
        The upscaled image as a PIL ``Image``. This is best-effort: on any
        error, including an unsupported ``scale``, the error is printed and
        the input ``image`` is returned unchanged.
    """
    try:
        print(f"Enhancing image with scale {scale}...")
        start_time = time.perf_counter()  # monotonic clock for durations

        # Select the model explicitly. Previously any unrecognized value
        # (None, a typo, ...) silently fell through to the 8x model.
        if scale == '2x':
            model = model2
        elif scale == '4x':
            model = model4
        elif scale == '8x':
            model = model8
        else:
            raise ValueError(f"Unsupported scale {scale!r}; expected '2x', '4x' or '8x'")

        image_np = np.array(image.convert('RGB'))
        print(f"Image converted to numpy array: shape {image_np.shape}, dtype {image_np.dtype}")

        result = model.predict(image_np)
        # Coerce the model output to uint8 before building a PIL image.
        enhanced_image = Image.fromarray(np.uint8(result))
        print(f"Image enhanced in {time.perf_counter() - start_time:.2f} seconds")
        return enhanced_image
    except Exception as e:
        # Deliberate best-effort: one failed enhancement must not abort the
        # whole batch, so report the error and hand back the input unchanged.
        print(f"Error enhancing image: {e}")
        return image

def muda_dpi(input_image, dpi):
    """Return a JPEG-re-encoded copy of ``input_image`` with ``dpi`` in its header.

    Args:
        input_image: Numpy array of pixel data, cast to uint8 and interpreted
            as RGB (presumably shape (H, W, 3) — not checked here).
        dpi: Resolution written for both the horizontal and vertical axes.

    Returns:
        A PIL image decoded from an in-memory JPEG encoding of the input.

    NOTE: Pillow takes the JPEG DPI from ``save()`` arguments. A caller that
    saves this image again must pass ``dpi=`` again for the value to persist.
    """
    dpi_tuple = (dpi, dpi)
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    # Encode in memory. The previous NamedTemporaryFile(delete=False) was never
    # removed, leaking one file in the temp directory per call.
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG', dpi=dpi_tuple)
    buffer.seek(0)  # rewind so Image.open reads from the start of the encoding
    return Image.open(buffer)

def resize_image(input_image, width, height):
    """Resize ``input_image`` to ``width`` x ``height`` and return it JPEG-re-encoded.

    Args:
        input_image: Numpy array of pixel data, cast to uint8 and interpreted
            as RGB (presumably shape (H, W, 3) — not checked here).
        width: Target width in pixels. Floats such as ``512.0`` are accepted
            and rounded to the nearest integer.
        height: Target height in pixels, handled like ``width``.

    Returns:
        A PIL image decoded from an in-memory JPEG encoding of the resized image.

    Raises:
        ValueError: If the rounded width or height is not positive.
    """
    # Pillow's resize() requires integer sizes. Callers may pass floats
    # (e.g. values from an un-rounded gr.Number), so normalize here.
    size = (int(round(width)), int(round(height)))
    if size[0] <= 0 or size[1] <= 0:
        raise ValueError(f"Width and height must be positive, got {width!r} x {height!r}")

    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    resized_image = image.resize(size)
    # Encode in memory instead of leaking a NamedTemporaryFile(delete=False).
    buffer = io.BytesIO()
    resized_image.save(buffer, format='JPEG')
    buffer.seek(0)
    return Image.open(buffer)

def _source_path(image_file):
    """Return the filesystem path of one upload from ``gr.Files``."""
    # Depending on the Gradio version (TODO confirm which one this app targets),
    # uploads arrive either as path strings or as file-like objects with `.name`.
    if isinstance(image_file, (str, os.PathLike)):
        return os.fspath(image_file)
    return image_file.name


def _safe_stem(path):
    """Return the base name of ``path`` without extension, reduced to [alnum, _, -]."""
    stem, _ = os.path.splitext(os.path.basename(path))
    # Keep only alphanumerics, spaces, underscores and hyphens; spaces become '_'.
    stem = ''.join(ch for ch in stem if ch.isalnum() or ch in (' ', '_', '-')).strip().replace(' ', '_')
    return stem or 'image'  # never produce a bare ".jpg"


def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
    """Apply the selected processing steps to each upload and save the results.

    Steps, applied in this order when their flag is set: ESRGAN enhancement
    (``enhance``/``scale``), DPI tagging (``adjust_dpi``/``dpi``) and resizing
    (``resize``/``width``/``height``). Each result is written as a JPEG into a
    fresh temporary directory created for this call.

    Args:
        image_files: Uploads from ``gr.Files``, as path strings or file-like
            objects exposing ``.name``. ``None`` or empty yields empty results.
        enhance: Whether to run ``enhance_image``.
        scale: Scale label passed to ``enhance_image`` ('2x', '4x' or '8x').
        adjust_dpi: Whether to tag the output with ``dpi``.
        dpi: Resolution written for both axes when ``adjust_dpi`` is set.
        resize: Whether to resize to ``width`` x ``height``.
        width: Target width in pixels (rounded to int).
        height: Target height in pixels (rounded to int).

    Returns:
        ``(processed_images, file_paths)``: the final PIL images and the paths
        of the saved JPEGs, both in upload order.
    """
    processed_images = []
    file_paths = []
    if not image_files:
        return processed_images, file_paths

    # One directory per call. The shared temp dir let concurrent requests, and
    # same-named uploads within one request, overwrite each other's output.
    output_dir = tempfile.mkdtemp(prefix='enhanced_')
    used_names = set()

    for image_file in image_files:
        source_path = _source_path(image_file)
        # Close the source promptly. convert() returns an independent, loaded copy.
        with Image.open(source_path) as src:
            image = src.convert('RGB')

        if enhance:
            image = enhance_image(image, scale)

        if adjust_dpi:
            image = muda_dpi(np.array(image), dpi)

        if resize:
            image = resize_image(np.array(image), int(round(width)), int(round(height)))

        # Disambiguate duplicate stems (compared case-insensitively for
        # case-insensitive filesystems) so no output is overwritten.
        stem = _safe_stem(source_path)
        name, counter = stem, 1
        while name.lower() in used_names:
            name = f"{stem}_{counter}"
            counter += 1
        used_names.add(name.lower())
        output_path = os.path.join(output_dir, f"{name}.jpg")

        # Pillow takes the JPEG DPI from save() arguments, not from the image,
        # and resize_image rebuilds the image from raw pixels. Pass the DPI
        # here, or the "Adjust DPI" choice never reaches the downloaded file.
        save_kwargs = {'dpi': (dpi, dpi)} if adjust_dpi else {}
        image.save(output_path, format='JPEG', **save_kwargs)

        processed_images.append(image)
        file_paths.append(output_path)

    return processed_images, file_paths

# Gradio UI. The input components map positionally onto process_images'
# parameters: image_files, enhance, scale, adjust_dpi, dpi, resize, width, height.
iface = gr.Interface(
    fn=process_images,
    inputs=[
        # Multiple uploads, restricted to image files so a stray non-image
        # cannot crash the whole batch when Pillow tries to open it.
        gr.Files(label="Upload Image Files", file_types=["image"]),
        gr.Checkbox(label="Enhance Images (ESRGAN)"),
        gr.Radio(['2x', '4x', '8x'], type="value", value='2x', label='Resolution model'),
        gr.Checkbox(label="Adjust DPI"),
        # precision=0 makes Gradio round and deliver ints. Without it these
        # arrive as floats (e.g. 512.0), which Pillow's resize() rejects.
        gr.Number(label="DPI", value=300, precision=0),
        gr.Checkbox(label="Resize"),
        gr.Number(label="Width", value=512, precision=0),
        gr.Number(label="Height", value=512, precision=0),
    ],
    outputs=[
        gr.Gallery(label="Final Images"),  # displays multiple images
        gr.Files(label="Download Final Images"),
    ],
    title="Multi-Image Enhancer",
    description="Upload multiple images (.jpg, .png), enhance using AI, adjust DPI, resize, and download the final results."
)

iface.launch(debug=True)