abiabidali's picture
Update app.py
5a91dbf verified
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
import gradio as gr
import numpy as np
import tempfile
import time
import os
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_model(scale):
    """Build a RealESRGAN model for ``scale`` and load its weights.

    The weights are expected at ``weights/RealESRGAN_x{scale}.pth``. Loading is
    first attempted with ``download=True``; if that raises, the error is
    printed and loading is retried once with ``download=False``. An exception
    from the retry propagates to the caller.

    Args:
        scale: Upscaling factor, used both for the model and the weights
            file name.

    Returns:
        The RealESRGAN instance with weights loaded.
    """
    upscaler = RealESRGAN(device, scale=scale)
    checkpoint = f'weights/RealESRGAN_x{scale}.pth'
    try:
        upscaler.load_weights(checkpoint, download=True)
        print(f"Weights for scale {scale} loaded successfully.")
    except Exception as exc:
        # Fall back to whatever is already on disk when the first attempt fails.
        print(f"Error loading weights for scale {scale}: {exc}")
        upscaler.load_weights(checkpoint, download=False)
    return upscaler
# Load one model per supported upscaling factor at import time (2x, then 4x, then 8x).
model2, model4, model8 = [load_model(factor) for factor in (2, 4, 8)]
def enhance_image(image, scale):
    """Upscale ``image`` with the model selected by ``scale``.

    ``'2x'`` selects ``model2`` and ``'4x'`` selects ``model4``; any other
    value falls through to ``model8``. This is a best-effort operation: if any
    step raises, the error is printed and the input image is returned as-is.

    Args:
        image: Image object providing ``convert('RGB')``.
        scale: Scale label such as ``'2x'``, ``'4x'`` or ``'8x'``.

    Returns:
        The upscaled PIL image, or ``image`` itself when enhancement failed.
    """
    try:
        print(f"Enhancing image with scale {scale}...")
        started = time.time()
        pixels = np.array(image.convert('RGB'))
        print(f"Image converted to numpy array: shape {pixels.shape}, dtype {pixels.dtype}")

        # Pick the model first, then run a single predict call.
        if scale == '2x':
            upscaler = model2
        elif scale == '4x':
            upscaler = model4
        else:
            upscaler = model8
        prediction = upscaler.predict(pixels)

        upscaled = Image.fromarray(np.uint8(prediction))
        print(f"Image enhanced in {time.time() - started:.2f} seconds")
        return upscaled
    except Exception as exc:
        print(f"Error enhancing image: {exc}")
        return image
def muda_dpi(input_image, dpi):
    """Re-encode an RGB pixel array as a JPEG saved with the given DPI.

    The JPEG is written to an in-memory buffer instead of a
    ``NamedTemporaryFile(delete=False)``, so no file is left behind on disk
    for every call.

    Args:
        input_image: Array of RGB pixel data; it is cast to ``uint8`` before
            being wrapped in a PIL image.
        dpi: Resolution passed to the JPEG encoder for both axes.

    Returns:
        A PIL image decoded from the freshly encoded JPEG data.
    """
    import io  # local import keeps this block self-contained

    dpi_tuple = (dpi, dpi)
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    buffer = io.BytesIO()
    image.save(buffer, format='JPEG', dpi=dpi_tuple)
    buffer.seek(0)

    result = Image.open(buffer)
    # Decode now so the returned image does not rely on lazy reads later.
    result.load()
    return result
def resize_image(input_image, width, height):
    """Resize an RGB pixel array to ``width`` x ``height`` pixels.

    The resized image is returned directly. The previous round trip through a
    temporary JPEG file left one file on disk per call and re-compressed the
    pixels without adding anything.

    Args:
        input_image: Array of RGB pixel data; it is cast to ``uint8`` before
            being wrapped in a PIL image.
        width: Target width in pixels; rounded to the nearest integer.
        height: Target height in pixels; rounded to the nearest integer.

    Returns:
        The resized PIL image.
    """
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    # Sizes may arrive as floats (e.g. from a numeric UI field); Pillow's
    # resize is documented with integer pixel sizes, so coerce them here.
    target_size = (int(round(width)), int(round(height)))
    return image.resize(target_size)
def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
    """Run the selected processing steps on every uploaded image.

    Each image is opened, converted to RGB, optionally enhanced, optionally
    re-encoded with a DPI value, optionally resized, and finally saved as a
    JPEG whose name is derived from the sanitized upload name.

    Args:
        image_files: Uploaded files (paths or objects exposing ``.name``).
            ``None`` or an empty list yields empty results.
        enhance: Whether to call ``enhance_image``.
        scale: Scale label forwarded to ``enhance_image``.
        adjust_dpi: Whether to call ``muda_dpi`` and tag the output with DPI.
        dpi: DPI value used when ``adjust_dpi`` is set.
        resize: Whether to call ``resize_image``.
        width: Target width forwarded to ``resize_image``.
        height: Target height forwarded to ``resize_image``.

    Returns:
        A tuple ``(processed_images, file_paths)``: the resulting PIL images
        and the paths of the saved JPEG files, in upload order.
    """
    # Nothing uploaded: return empty results instead of failing on iteration.
    if not image_files:
        return [], []

    processed_images = []
    file_paths = []

    # A private directory per call keeps same-named uploads from different
    # requests from overwriting each other in the shared temp directory.
    output_dir = tempfile.mkdtemp()
    used_names = set()

    # The DPI must be passed to the final save explicitly; otherwise the
    # value chosen in the UI is not written to the downloadable file.
    save_options = {'dpi': (dpi, dpi)} if adjust_dpi and dpi else {}

    for image_file in image_files:
        with Image.open(image_file) as source:
            original_image = source.convert('RGB')

        if enhance:
            original_image = enhance_image(original_image, scale)
        if adjust_dpi:
            original_image = muda_dpi(np.array(original_image), dpi)
        if resize:
            original_image = resize_image(np.array(original_image), width, height)

        # Sanitize the base filename: keep only alphanumerics, spaces,
        # underscores and hyphens, then turn spaces into underscores.
        source_name = getattr(image_file, 'name', image_file)
        base_name = os.path.basename(str(source_name))
        stem, _ = os.path.splitext(base_name)
        stem = ''.join(
            ch for ch in stem if ch.isalnum() or ch in (' ', '_', '-')
        ).strip().replace(' ', '_')
        if not stem:
            # Every character was stripped; avoid producing a bare ".jpg".
            stem = 'image'

        # Add a numeric suffix only when two uploads map to the same name.
        unique_stem = stem
        counter = 1
        while unique_stem in used_names:
            unique_stem = f"{stem}_{counter}"
            counter += 1
        used_names.add(unique_stem)

        output_path = os.path.join(output_dir, f"{unique_stem}.jpg")
        original_image.save(output_path, format='JPEG', **save_options)

        processed_images.append(original_image)
        file_paths.append(output_path)

    return processed_images, file_paths
# Input widgets, listed in the same order as the parameters of process_images.
input_components = [
    gr.Files(label="Upload Image Files"),  # gr.Files accepts several uploads at once
    gr.Checkbox(label="Enhance Images (ESRGAN)"),
    gr.Radio(['2x', '4x', '8x'], type="value", value='2x', label='Resolution model'),
    gr.Checkbox(label="Adjust DPI"),
    gr.Number(label="DPI", value=300),
    gr.Checkbox(label="Resize"),
    gr.Number(label="Width", value=512),
    gr.Number(label="Height", value=512),
]

# Output widgets: a gallery preview plus the downloadable result files.
output_components = [
    gr.Gallery(label="Final Images"),  # gr.Gallery shows several images together
    gr.Files(label="Download Final Images"),
]

iface = gr.Interface(
    fn=process_images,
    inputs=input_components,
    outputs=output_components,
    title="Multi-Image Enhancer",
    description="Upload multiple images (.jpg, .png), enhance using AI, adjust DPI, resize, and download the final results.",
)

# Start the web UI with debug output enabled.
iface.launch(debug=True)