"""Gradio app applying the "Musica" enhancement pipeline to grayscale images.

Pipeline: wavelet decomposition with detail-coefficient shaping, CLAHE,
gamma correction, then conversion to 8-bit RGB for display alongside a
histogram of the enhanced intensities.
"""

import logging
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (kept for backward compatibility)
import numpy as np
import pywt
from matplotlib.figure import Figure
from PIL import Image
from skimage import exposure

logger = logging.getLogger(__name__)

# Optional: Print Gradio version for verification
print("Gradio version:", gr.__version__)

# PIL modes whose numpy representation is used directly (grayscale of various
# bit depths, RGB, RGBA). Any other mode (palette "P", "1", "LA", "CMYK", ...)
# is converted to RGB first so np.array() yields pixel values rather than
# palette indices or a foreign colour space.
_DIRECT_PIL_MODES = frozenset(
    {"L", "I", "I;16", "I;16L", "I;16B", "I;16N", "F", "RGB", "RGBA"}
)


def _to_unit_gray(image) -> np.ndarray:
    """Return *image* as a 2-D float32 grayscale array scaled to [0, 1].

    Scaling depends on the input dtype: uint8 is divided by 255; every other
    integer type is treated as 16-bit data and divided by 65535; floating-point
    data already within [0, 1] is used as-is, otherwise it is also treated as
    16-bit. The result is clipped to [0, 1].

    Raises:
        ValueError: if the image is not 2-D or 3-D with 1, 3 or 4 channels.
    """
    if isinstance(image, Image.Image) and image.mode not in _DIRECT_PIL_MODES:
        image = image.convert("RGB")
    img = np.array(image)
    logger.debug("Uploaded image shape: %s, dtype: %s", img.shape, img.dtype)
    logger.debug("Image min: %s, max: %s", img.min(), img.max())

    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 1:
            img = img[:, :, 0]
        elif channels == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            logger.debug("Converted RGB image to grayscale.")
        elif channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            logger.debug("Converted RGBA image to grayscale.")
    if img.ndim != 2:
        raise ValueError(
            f"Expected a grayscale, RGB or RGBA image, got shape {img.shape}"
        )

    # Scale by the dtype's nominal range instead of blindly casting to uint16:
    # casting 8-bit data to uint16 and dividing by 65535 yields an almost
    # black image, and casting [0, 1] floats truncates them to 0/1.
    if img.dtype == np.uint8:
        scale = 255.0
    elif np.issubdtype(img.dtype, np.integer):
        scale = 65535.0
    elif img.max() <= 1.0:
        scale = 1.0
    else:
        scale = 65535.0
    return np.clip(img.astype(np.float32) / scale, 0.0, 1.0)


def _wavelet_enhance(img_norm: np.ndarray) -> np.ndarray:
    """Shape the wavelet detail coefficients of *img_norm* and reconstruct it.

    Uses a 3-level 'bior1.3' decomposition, soft-thresholds the diagonal
    details at levels 1 and 2 (relative to their largest magnitude), and
    boosts the level-1 horizontal and vertical details by 1.2. Returns an
    array with the same shape as *img_norm*, clipped to [0, 1].
    """
    coeffs = pywt.wavedec2(img_norm, "bior1.3", level=3)
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs
    # Threshold relative to the largest *magnitude*: np.max of signed
    # coefficients can be negative, which would give a negative threshold.
    cD1 = pywt.threshold(cD1, 0.05 * np.max(np.abs(cD1)), mode="soft")
    cD2 = pywt.threshold(cD2, 0.07 * np.max(np.abs(cD2)), mode="soft")
    cH1 = cH1 * 1.2
    cV1 = cV1 * 1.2
    coeffs_enhanced = [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)]
    img_recon = pywt.waverec2(coeffs_enhanced, "bior1.3")
    # waverec2 can return one extra row/column for odd-sized inputs.
    rows, cols = img_norm.shape
    return np.clip(img_recon[:rows, :cols], 0, 1)


def _histogram_image(img_8bit: np.ndarray) -> Image.Image:
    """Render a 256-bin intensity histogram of *img_8bit* as a PIL image."""
    # Object-oriented Figure API instead of pyplot: pyplot keeps global
    # figure state, which is unsafe when Gradio serves concurrent requests.
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.hist(img_8bit.ravel(), bins=256, range=(0, 255), color="gray")
    ax.set_title("Enhanced Histogram")
    ax.set_xlabel("Pixel Intensity")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    histogram = Image.open(buf)
    histogram.load()  # decode eagerly rather than lazily from the buffer
    return histogram


def musica_enhancement(image):
    """Enhance *image* and return it together with its intensity histogram.

    Args:
        image: A PIL image (or array-like): grayscale at 8/16-bit depth or
            float, or RGB/RGBA which is converted to grayscale. ``None`` when
            the user runs the app without uploading anything.

    Returns:
        A tuple ``(enhanced_image, histogram)``: the enhanced image as an
        8-bit RGB PIL image, and a PNG-rendered histogram of its intensities.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
        ValueError: if the image has an unsupported shape.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    img_norm = _to_unit_gray(image)
    logger.debug(
        "Normalized image min: %s, max: %s", img_norm.min(), img_norm.max()
    )

    img_recon = _wavelet_enhance(img_norm)
    logger.debug(
        "Reconstructed image min: %s, max: %s", img_recon.min(), img_recon.max()
    )

    # CLAHE followed by a mild gamma correction.
    img_clahe = exposure.equalize_adapthist(
        img_recon, clip_limit=0.02, kernel_size=64
    )
    img_gamma = exposure.adjust_gamma(img_clahe, gamma=0.9)

    # Convert to 8-bit, then to RGB for better viewing compatibility.
    img_gamma_8bit = (img_gamma * 255).astype(np.uint8)
    img_rgb = cv2.cvtColor(img_gamma_8bit, cv2.COLOR_GRAY2RGB)
    enhanced_image = Image.fromarray(img_rgb)

    return enhanced_image, _histogram_image(img_gamma_8bit)


# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Musica Image Enhancement")
    gr.Markdown(
        """
        Upload a 16-bit TIFF image to enhance it using wavelet decomposition,
        CLAHE, gamma correction, and edge-preserving sharpening.
        """
    )
    with gr.Row():
        with gr.Column():
            # image_mode=None keeps the uploaded file's own PIL mode
            # (e.g. "I;16"). Gradio's default image_mode ("RGB") would convert
            # the upload to 8-bit RGB before musica_enhancement receives it.
            input_image = gr.Image(
                type="pil",
                image_mode=None,
                label="Upload 16-bit TIFF Image",
            )
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Image")
            output_hist = gr.Image(type="pil", label="Enhanced Histogram")

    run_button.click(
        fn=musica_enhancement,
        inputs=input_image,
        outputs=[output_image, output_hist],
    )

    gr.Examples(
        examples=[
            "examples/sample1.tif",
            "examples/sample2.tif",
            "examples/sample3.tif",
        ],
        inputs=input_image,
        label="Or try one of these examples",
    )

if __name__ == "__main__":
    demo.launch()