Spaces:
Sleeping
Sleeping
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pywt | |
from skimage import exposure | |
import gradio as gr | |
from io import BytesIO | |
from PIL import Image | |
# Optional startup diagnostic: record which Gradio release is serving the app.
_gradio_version = gr.__version__
print("Gradio version:", _gradio_version)
def _to_grayscale(img):
    """Return a 2-D single-channel array from a 2-D or H x W x C image array.

    Raises:
        gr.Error: if the array is not a 2-D image or an image with 1-4 channels.
    """
    if img.ndim == 2:
        return img
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 3:
            print("Converted RGB image to grayscale.")
            return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        if channels == 4:
            # A 4-channel array is assumed to be RGBA (alpha is dropped).
            print("Converted RGBA image to grayscale.")
            return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        if channels in (1, 2):
            # Single channel, or luminance + alpha: keep the first (luminance) plane.
            return img[:, :, 0]
    raise gr.Error(
        f"Unsupported image shape {img.shape}; expected a grayscale or RGB(A) image."
    )


def _to_uint16(img):
    """Map a 2-D boolean, integer or float array onto the full uint16 range."""
    if img.dtype.kind == "u" and img.dtype.itemsize == 2:
        # Already 16-bit; astype also normalises a big-endian ('>u2') array to native order.
        return img.astype(np.uint16, copy=False)
    print("Converted image to 16-bit.")
    if img.dtype == np.bool_:
        return img.astype(np.uint16) * np.uint16(65535)
    if img.dtype == np.uint8:
        # 255 * 257 == 65535, so 8-bit white maps exactly to 16-bit white.
        return img.astype(np.uint16) * np.uint16(257)
    # Other integer / float inputs (e.g. 32-bit int or float images) are taken to
    # already be on the 16-bit scale, as the original cast assumed. Work in float64
    # and clip so out-of-range values saturate instead of wrapping around.
    as_float = np.nan_to_num(img.astype(np.float64), nan=0.0)
    return np.clip(as_float, 0, 65535).astype(np.uint16)


def _relative_soft_threshold(band, fraction):
    """Soft-threshold a wavelet detail band at *fraction* of its peak magnitude."""
    peak = float(np.max(np.abs(band)))
    if peak == 0.0:
        # Flat band: nothing to shrink, and a zero soft threshold makes
        # pywt.threshold compute 0/0 (NaN) for zero coefficients.
        return band
    return pywt.threshold(band, fraction * peak, mode='soft')


def _wavelet_enhance(img_norm):
    """Shrink fine diagonal detail and boost finest H/V detail using a 3-level bior1.3 DWT.

    Returns a float array with the same shape as *img_norm*, clipped to [0, 1].
    """
    coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
    # wavedec2 orders coefficients coarsest-first: level 3 details, then 2, then 1 (finest).
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs
    cD1 = _relative_soft_threshold(cD1, 0.05)
    cD2 = _relative_soft_threshold(cD2, 0.07)
    cH1 = cH1 * 1.2
    cV1 = cV1 * 1.2
    coeffs_enhanced = [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)]
    img_recon = pywt.waverec2(coeffs_enhanced, 'bior1.3')
    # waverec2 can return one extra row/column for odd-sized inputs; crop back.
    img_recon = img_recon[: img_norm.shape[0], : img_norm.shape[1]]
    return np.clip(img_recon, 0, 1)


def _render_histogram(img_8bit):
    """Render a 256-bin intensity histogram of *img_8bit* and return it as a PIL image."""
    # Build the figure without pyplot: pyplot keeps global figure state, which
    # Matplotlib advises against when rendering inside a web server.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 4))
    ax = fig.add_subplot()
    # range=(0, 256) gives one unit-wide bin per 8-bit intensity level.
    ax.hist(img_8bit.ravel(), bins=256, range=(0, 256), color='gray')
    ax.set_title('Enhanced Histogram')
    ax.set_xlabel('Pixel Intensity')
    ax.set_ylabel('Frequency')
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    histogram = Image.open(buf)
    histogram.load()  # decode eagerly instead of lazily reading from the buffer later
    return histogram


def musica_enhancement(image):
    """Enhance an image with wavelet filtering, CLAHE and gamma correction.

    Pipeline: grayscale -> uint16 -> [0, 1] float -> 3-level bior1.3 wavelet
    detail shrinkage/boost -> CLAHE (clip_limit=0.02, kernel_size=64) ->
    gamma 0.9 -> 8-bit.

    Args:
        image: the uploaded image as delivered by the Gradio input component
            (a PIL image, or None when nothing was uploaded).

    Returns:
        A tuple ``(enhanced_image, histogram)``. The first item is the enhanced
        result as an RGB PIL image, the same size as the input. The second is a
        PIL image of its 8-bit intensity histogram.

    Raises:
        gr.Error: if no image was provided or the image shape is unsupported.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    img = np.array(image)
    # Debugging: print image properties
    print(f"Uploaded image shape: {img.shape}, dtype: {img.dtype}")
    print(f"Image min: {img.min()}, max: {img.max()}")

    img = _to_uint16(_to_grayscale(img))

    # Normalize to [0, 1]
    img_norm = img.astype(np.float32) / 65535.0
    print(f"Normalized image min: {img_norm.min()}, max: {img_norm.max()}")

    img_recon = _wavelet_enhance(img_norm)
    print(f"Reconstructed image min: {img_recon.min()}, max: {img_recon.max()}")

    img_clahe = exposure.equalize_adapthist(img_recon, clip_limit=0.02, kernel_size=64)
    img_gamma = exposure.adjust_gamma(img_clahe, gamma=0.9)

    # Round (rather than truncate) to 8-bit so values are not biased downward.
    img_gamma_8bit = np.clip(np.rint(img_gamma * 255), 0, 255).astype(np.uint8)

    # Convert to RGB for better viewing compatibility
    img_rgb = cv2.cvtColor(img_gamma_8bit, cv2.COLOR_GRAY2RGB)
    enhanced_image = Image.fromarray(img_rgb)

    return enhanced_image, _render_histogram(img_gamma_8bit)
# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Musica Image Enhancement")
    gr.Markdown(
        """
        Upload a 16-bit TIFF image to enhance it using wavelet-based denoising and
        fine-detail boosting, CLAHE (adaptive histogram equalization), and gamma
        correction.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                # Gradio's default image_mode="RGB" converts uploads to 8-bit RGB,
                # discarding 16-bit depth. None keeps the file's own mode so 16-bit
                # TIFFs reach musica_enhancement at full precision.
                image_mode=None,
                label="Upload 16-bit TIFF Image",
            )
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Image")
            output_hist = gr.Image(type="pil", label="Enhanced Histogram")
    run_button.click(
        fn=musica_enhancement,
        inputs=input_image,
        outputs=[output_image, output_hist],
    )
    gr.Examples(
        examples=[
            "examples/sample1.tif",
            "examples/sample2.tif",
            "examples/sample3.tif",
        ],
        inputs=input_image,
        label="Or try one of these examples",
    )
def _launch():
    """Start the Gradio server hosting ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch()