Spaces:
Build error
Build error
import io
import os
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_model(scale):
    """Build a RealESRGAN upscaler for ``scale`` on ``device`` and load its weights.

    Weights live at ``weights/RealESRGAN_x{scale}.pth``. The first attempt
    calls ``load_weights`` with ``download=True``; if that raises, the error is
    printed and a second attempt with ``download=False`` is made. Exceptions
    from the second attempt propagate to the caller.

    Args:
        scale: upscaling factor passed to ``RealESRGAN`` and used in the
            weights filename.

    Returns:
        The RealESRGAN instance with weights loaded.
    """
    weights_path = f'weights/RealESRGAN_x{scale}.pth'
    upsampler = RealESRGAN(device, scale=scale)
    try:
        upsampler.load_weights(weights_path, download=True)
        print(f"Weights for scale {scale} loaded successfully.")
    except Exception as e:
        print(f"Error loading weights for scale {scale}: {e}")
        # Fallback without downloading — presumably for when a local copy
        # already exists at weights_path (TODO confirm intent).
        upsampler.load_weights(weights_path, download=False)
    return upsampler
# Eagerly build one upscaler per supported factor (2x, 4x, 8x) at import time;
# enhance_image picks among these module-level globals.
model2, model4, model8 = [load_model(factor) for factor in (2, 4, 8)]
def enhance_image(image, scale):
    """Upscale ``image`` with the Real-ESRGAN model selected by ``scale``.

    Args:
        image: an image object supporting ``.convert('RGB')`` (a PIL image
            when called from process_images).
        scale: ``'2x'`` selects model2 and ``'4x'`` selects model4; any other
            value falls through to model8.

    Returns:
        The upscaled PIL image. If any step raises, the error is printed and
        the unmodified ``image`` is returned instead (best-effort behavior).
    """
    try:
        print(f"Enhancing image with scale {scale}...")
        start_time = time.time()
        image_np = np.array(image.convert('RGB'))
        print(f"Image converted to numpy array: shape {image_np.shape}, dtype {image_np.dtype}")
        # Chained conditional keeps lazy lookup and the original precedence:
        # unmatched scale values use the 8x model.
        upscaler = model2 if scale == '2x' else (model4 if scale == '4x' else model8)
        upscaled = upscaler.predict(image_np)
        enhanced = Image.fromarray(np.uint8(upscaled))
        print(f"Image enhanced in {time.time() - start_time:.2f} seconds")
        return enhanced
    except Exception as e:
        print(f"Error enhancing image: {e}")
        return image
def muda_dpi(input_image, dpi):
    """Re-encode a pixel array as JPEG whose header carries the requested DPI.

    Args:
        input_image: array of pixels, cast to uint8 and interpreted as RGB
            (presumably shape (H, W, 3) — TODO confirm at call sites).
        dpi: resolution written for both the horizontal and vertical axes.

    Returns:
        A PIL image decoded from the JPEG encoding. The encoding is lossy
        (Pillow's default JPEG quality), exactly as before.

    The round-trip goes through an in-memory buffer. The previous
    ``NamedTemporaryFile(delete=False)`` left one orphaned file on disk per
    call and was never cleaned up.
    """
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG', dpi=(dpi, dpi))
    buffer.seek(0)  # rewind so Image.open reads from the start of the JPEG
    return Image.open(buffer)
def resize_image(input_image, width, height):
    """Resize a pixel array to ``width`` x ``height`` and return it JPEG-encoded.

    Args:
        input_image: array of pixels, cast to uint8 and interpreted as RGB
            (presumably shape (H, W, 3) — TODO confirm at call sites).
        width, height: target size in pixels. ``gr.Number`` delivers floats
            unless ``precision=0`` is set, and Pillow's ``resize`` needs
            integers, so both are rounded to ``int`` here.

    Returns:
        A PIL image decoded from an in-memory JPEG encoding of the resized
        picture. Compression is the same as before; only the leaked temporary
        file (``delete=False`` and never removed) is gone.
    """
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    target_size = (int(round(width)), int(round(height)))
    resized_image = image.resize(target_size)
    buffer = io.BytesIO()
    resized_image.save(buffer, format='JPEG')
    buffer.seek(0)  # rewind so Image.open reads from the start of the JPEG
    return Image.open(buffer)
def _safe_stem(path):
    """Return ``path``'s basename without extension, reduced to a safe file stem."""
    stem, _ = os.path.splitext(os.path.basename(path))
    # Keep only alphanumerics, spaces, underscores and hyphens; spaces become '_'.
    stem = ''.join(ch for ch in stem if ch.isalnum() or ch in (' ', '_', '-'))
    # A name made only of stripped characters would otherwise yield ".jpg".
    return stem.strip().replace(' ', '_') or 'image'


def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
    """Run the enhance -> DPI -> resize pipeline over every uploaded image.

    Args:
        image_files: uploads from ``gr.Files``. Each is either a file wrapper
            exposing ``.name`` (older Gradio) or a plain path string (newer
            Gradio). ``None`` or an empty list yields empty results.
        enhance, scale: when ``enhance`` is truthy, upscale via
            ``enhance_image`` using the ``scale`` choice.
        adjust_dpi, dpi: when ``adjust_dpi`` is truthy, tag images with ``dpi``.
        resize, width, height: when ``resize`` is truthy, resize to
            ``width`` x ``height``.

    Returns:
        ``(images, paths)``: the processed PIL images for the gallery and the
        JPEG files written for download, both in upload order.
    """
    processed_images = []
    file_paths = []
    if not image_files:
        return processed_images, file_paths

    # Fresh directory per call. Writing into the shared temp dir let concurrent
    # requests, or two uploads with the same name, overwrite each other's
    # output. The download names stay "<name>.jpg".
    out_dir = tempfile.mkdtemp(prefix='enhanced_')
    # Pillow only writes DPI that is passed to save(). Without this, the final
    # JPEG (and any resize done after muda_dpi) silently drops the adjusted DPI.
    save_kwargs = {'dpi': (dpi, dpi)} if adjust_dpi else {}
    used_names = set()

    for image_file in image_files:
        with Image.open(image_file) as opened:
            rgb_pixels = np.array(opened.convert('RGB'))
        current = Image.fromarray(rgb_pixels.astype('uint8'), 'RGB')
        if enhance:
            current = enhance_image(current, scale)
        if adjust_dpi:
            current = muda_dpi(np.array(current), dpi)
        if resize:
            current = resize_image(np.array(current), width, height)

        stem = _safe_stem(getattr(image_file, 'name', image_file))
        name, counter = stem, 1
        # casefold: names differing only by case collide on case-insensitive filesystems.
        while name.casefold() in used_names:
            name = f"{stem}_{counter}"
            counter += 1
        used_names.add(name.casefold())

        output_path = os.path.join(out_dir, f"{name}.jpg")
        current.save(output_path, format='JPEG', **save_kwargs)
        processed_images.append(current)
        file_paths.append(output_path)
    return processed_images, file_paths
# Gradio UI: the input order must match process_images' parameter order.
iface = gr.Interface(
    fn=process_images,
    inputs=[
        # Multiple uploads at once; restricted to images as the description promises.
        gr.Files(label="Upload Image Files", file_types=["image"]),
        gr.Checkbox(label="Enhance Images (ESRGAN)"),
        gr.Radio(['2x', '4x', '8x'], type="value", value='2x', label='Resolution model'),
        gr.Checkbox(label="Adjust DPI"),
        # precision=0 makes gr.Number return ints. The default float values
        # would break Pillow's resize, which needs integer sizes.
        gr.Number(label="DPI", value=300, precision=0),
        gr.Checkbox(label="Resize"),
        gr.Number(label="Width", value=512, precision=0),
        gr.Number(label="Height", value=512, precision=0),
    ],
    outputs=[
        gr.Gallery(label="Final Images"),  # shows every processed image
        gr.Files(label="Download Final Images"),
    ],
    title="Multi-Image Enhancer",
    description="Upload multiple images (.jpg, .png), enhance using AI, adjust DPI, resize, and download the final results."
)
iface.launch(debug=True)