from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pywt
from PIL import Image
from skimage import exposure

# Optional: Print Gradio version for verification
print("Gradio version:", gr.__version__)

# PIL modes that hold unsigned 16-bit samples (plain and explicit-endianness).
_UINT16_MODES = ('I;16', 'I;16L', 'I;16B', 'I;16N')


def _to_normalized_gray(image):
    """Convert a PIL image to a 2-D float32 array scaled to [0, 1].

    Raises:
        ValueError: If ``image`` is None or its mode is not supported.
    """
    if image is None:
        # Gradio passes None when the button is clicked with no upload.
        raise ValueError("No image provided. Please upload an image first.")

    mode = image.mode
    img = np.array(image)

    if mode in _UINT16_MODES:
        return img.astype(np.float32) / 65535.0
    if mode == 'I':
        # Mode 'I' is signed 32-bit, so no fixed divisor maps it onto
        # [0, 1]; stretch the actual data range instead.
        data = img.astype(np.float64)
        lo = data.min()
        hi = data.max()
        if hi <= lo:
            return np.zeros(data.shape, dtype=np.float32)
        return ((data - lo) / (hi - lo)).astype(np.float32)
    if mode in ('RGB', 'RGBA'):
        # Convert to grayscale if it's a color image
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return gray.astype(np.float32) / 255.0
    if mode == 'L':
        return img.astype(np.float32) / 255.0
    raise ValueError(f"Unsupported image mode: {image.mode}")


def _soft_threshold(coeffs, fraction):
    """Soft-threshold ``coeffs`` at ``fraction`` of their largest magnitude."""
    threshold = fraction * np.max(np.abs(coeffs))
    if threshold <= 0:
        # All-zero band: nothing to suppress, and a zero threshold would
        # produce 0/0 inside the soft-threshold formula.
        return coeffs
    return pywt.threshold(coeffs, threshold, mode='soft')


def _wavelet_enhance(img_norm):
    """Denoise/boost detail bands of a 3-level wavelet transform; return [0, 1] image."""
    coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs

    # Suppress small diagonal details, boost finest horizontal/vertical ones.
    cD1 = _soft_threshold(cD1, 0.05)
    cD2 = _soft_threshold(cD2, 0.07)
    cH1 = cH1 * 1.2
    cV1 = cV1 * 1.2

    coeffs_enhanced = [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)]
    img_recon = pywt.waverec2(coeffs_enhanced, 'bior1.3')

    # The reconstruction can be one sample larger than the input along
    # odd-length axes; crop so the output matches the input shape.
    rows, cols = img_norm.shape
    img_recon = img_recon[:rows, :cols]

    # Coefficient scaling can push values outside [0, 1]; later steps
    # (log, CLAHE, 8-bit conversion) expect that range.
    return np.clip(img_recon, 0, 1)


def _shannon_entropy(img_unit):
    """Return the entropy (bits) of a [0, 1] image over a 256-bin histogram."""
    hist, _ = np.histogram(img_unit, bins=256, range=(0.0, 1.0))
    total = hist.sum()
    if total == 0:
        return 0.0
    probs = hist[hist > 0] / total
    return float(-np.sum(probs * np.log2(probs)))


def _render_histogram(pixels):
    """Plot the 8-bit intensity histogram of ``pixels`` and return it as a PIL image."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.hist(pixels.ravel(), bins=256, range=(0, 255), color='gray')
        ax.set_title('Enhanced Histogram')
        ax.set_xlabel('Pixel Intensity')
        ax.set_ylabel('Frequency')
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def musica_enhancement(image):
    """
    Enhances a 16-bit TIFF image using wavelet decomposition, CLAHE,
    gamma correction, and edge-preserving sharpening.

    Args:
        image (PIL.Image): Uploaded image. Supported modes are 'I;16'
            (and its explicit-endianness variants), 'I', 'RGB', 'RGBA'
            and 'L'.

    Returns:
        enhanced_image (PIL.Image): Enhanced 8-bit grayscale image.
        histogram (PIL.Image): Histogram of the enhanced image.

    Raises:
        ValueError: If no image is supplied or the image mode is unsupported.
    """
    img_norm = _to_normalized_gray(image)

    # 1-3. Multi-scale decomposition, per-band processing, reconstruction
    img_recon = _wavelet_enhance(img_norm)

    # 4. Adaptive CLAHE: high-entropy images get a lower clip limit
    entropy = _shannon_entropy(img_recon)
    clip_limit = 0.02 if entropy > 7 else 0.05
    img_clahe = exposure.equalize_adapthist(
        img_recon, clip_limit=clip_limit, kernel_size=64
    )

    # 5. Gamma correction: lower gamma when the 5-95 percentile spread is narrow
    p5, p95 = np.percentile(img_clahe, (5, 95))
    gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
    img_gamma = exposure.adjust_gamma(img_clahe, gamma=gamma)

    # 6. Edge-Preserving Sharpening (convert to BGR first)
    img_gamma_8bit = (img_gamma * 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_gamma_8bit, cv2.COLOR_GRAY2BGR)  # Convert to 3-channel
    img_sharp = cv2.detailEnhance(img_bgr, sigma_s=12, sigma_r=0.15)
    img_sharp = cv2.cvtColor(img_sharp, cv2.COLOR_BGR2GRAY)  # Convert back to grayscale

    enhanced_image = Image.fromarray(img_sharp)
    histogram = _render_histogram(img_sharp)

    return enhanced_image, histogram


# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Musica Image Enhancement")
    gr.Markdown(
        """
        Upload a 16-bit TIFF image to enhance it using wavelet decomposition, CLAHE, gamma correction, and edge-preserving sharpening.
        """
    )

    with gr.Row():
        with gr.Column():
            # NOTE(review): `tool` appears to be a Gradio 3.x `gr.Image`
            # argument; confirm it is accepted by the installed version.
            input_image = gr.Image(
                type="pil",
                label="Upload 16-bit TIFF Image",
                tool="editor",
                # Remove 'source' parameter
            )
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Image")
            output_hist = gr.Image(type="pil", label="Enhanced Histogram")

    run_button.click(
        fn=musica_enhancement,
        inputs=input_image,
        outputs=[output_image, output_hist],
    )

    gr.Examples(
        examples=[
            "sample1.tif",
            "sample2.tif",
            "sample3.tif"
        ],
        inputs=input_image,
        label="Or try one of these examples",
    )

if __name__ == "__main__":
    demo.launch()