"""Gradio app for wavelet-based enhancement of grayscale X-ray TIFF images."""

from io import BytesIO
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (kept; the histogram uses the OO Figure API)
import numpy as np
import pywt
from matplotlib.figure import Figure
from PIL import Image
from skimage import exposure

WAVELET = "bior1.3"
LEVELS = 3
SAMPLE_PATH = Path("sample.tif")


def _resolve_path(file):
    """Return a filesystem path for the value Gradio hands to the callback.

    Depending on the Gradio version (and whether the call comes from
    ``gr.Examples``), the value is either a plain path string or a
    tempfile-like object exposing ``.name``.

    Raises:
        gr.Error: if no file was provided.
    """
    if file is None:
        raise gr.Error("Please upload a TIFF file")
    if isinstance(file, (str, Path)):
        return str(file)
    return file.name


def _load_tiff(path):
    """Read *path* as a single-channel image at its native bit depth.

    Raises:
        gr.Error: if OpenCV cannot decode the file.
    """
    try:
        # IMREAD_GRAYSCALE is 0, so the effective flag is IMREAD_ANYDEPTH:
        # one channel, and 16/32-bit data is not down-converted to 8-bit.
        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError("Invalid or corrupted TIFF file")
    except Exception as e:
        raise gr.Error(f"Error reading file: {str(e)}") from e
    return img


def _normalize(img):
    """Return *img* as float32 scaled to [0, 1].

    Integer images are divided by their dtype's maximum (65535 for uint16,
    255 for uint8). Floating-point images have no intrinsic full scale:
    they are divided by their own maximum when it exceeds 1, then clipped.
    """
    if np.issubdtype(img.dtype, np.integer):
        return img.astype(np.float32) / float(np.iinfo(img.dtype).max)
    img_f = img.astype(np.float32)
    peak = float(np.max(img_f)) if img_f.size else 0.0
    if peak > 1.0:
        img_f = img_f / peak
    return np.clip(img_f, 0.0, 1.0)


def _image_entropy(img, bins=256):
    """Return the Shannon entropy, in bits, of *img*'s histogram over [0, 1]."""
    hist, _ = np.histogram(img, bins=bins, range=(0.0, 1.0))
    total = hist.sum()
    if total == 0:
        return 0.0
    p = hist[hist > 0] / total
    return float(-np.sum(p * np.log2(p)))


def _soft_threshold(band, fraction):
    """Soft-threshold *band* at *fraction* of its largest coefficient magnitude.

    Detail coefficients are signed, so the level must come from ``abs`` —
    a plain ``max`` can be negative or near zero.
    """
    return pywt.threshold(band, fraction * float(np.max(np.abs(band))), "soft")


def _enhance(img_norm):
    """Run the wavelet / CLAHE / gamma / detail-enhance pipeline.

    Args:
        img_norm: 2-D float image in [0, 1].

    Returns:
        2-D uint8 enhanced image with the same shape as *img_norm*.
    """
    coeffs = pywt.wavedec2(img_norm, WAVELET, level=LEVELS)
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs

    # Denoise the finest diagonal bands; boost finest horizontal/vertical edges.
    cD1 = _soft_threshold(cD1, 0.05)
    cD2 = _soft_threshold(cD2, 0.07)
    cH1 = cH1 * 1.2
    cV1 = cV1 * 1.2

    recon = pywt.waverec2(
        [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)], WAVELET
    )
    # waverec2 can return one extra row/column for odd intermediate sizes;
    # crop back to the input shape before clipping to the valid range.
    recon = np.clip(recon[: img_norm.shape[0], : img_norm.shape[1]], 0, 1)

    # Histogram entropy is at most log2(256) = 8 bits; busy images get a
    # gentler CLAHE clip limit.
    entropy = _image_entropy(recon)
    clahe_img = exposure.equalize_adapthist(
        recon, clip_limit=0.02 if entropy > 7 else 0.05, kernel_size=64
    )

    # Stronger gamma for low-contrast results (narrow 5–95 percentile spread).
    p5, p95 = np.percentile(clahe_img, (5, 95))
    gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
    gamma_img = exposure.adjust_gamma(clahe_img, gamma)

    # detailEnhance needs an 8-bit 3-channel input, so round-trip through BGR.
    bgr = cv2.cvtColor((gamma_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
    sharp = cv2.detailEnhance(bgr, sigma_s=12, sigma_r=0.15)
    return cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)


def _display_original(img):
    """Return an 8-bit preview of *img* clipped at its 99.5th percentile."""
    scale = float(np.percentile(img, 99.5))
    if scale <= 0:
        # Mostly-black input: fall back to the maximum to avoid dividing by 0.
        scale = float(np.max(img))
    if scale <= 0:
        return np.zeros(img.shape, dtype=np.uint8)
    return (np.clip(img / scale, 0, 1) * 255).astype(np.uint8)


def _histogram_image(pixels):
    """Render a 256-bin histogram of uint8 *pixels* and return it as a PIL image."""
    # The object-oriented Figure API avoids pyplot's global "current figure",
    # which is not safe when Gradio runs callbacks on worker threads.
    fig = Figure()
    ax = fig.subplots()
    ax.hist(pixels.ravel(), bins=256, range=(0, 255))
    ax.set_title("Enhanced Histogram")
    ax.set_xlabel("Pixel Value")
    ax.set_ylabel("Frequency")

    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    hist_img = Image.open(buf)
    hist_img.load()  # decode now rather than lazily later
    return hist_img


def process_tiff(file):
    """Enhance an uploaded grayscale TIFF.

    Args:
        file: the Gradio ``gr.File`` value (path string or tempfile-like
            object with ``.name``).

    Returns:
        Tuple ``(original_display, enhanced, histogram)``: two 2-D uint8
        arrays and a PIL image of the enhanced image's histogram.

    Raises:
        gr.Error: on a missing/unreadable file, dimensions not divisible by
            8, or any failure inside the enhancement pipeline.
    """
    img = _load_tiff(_resolve_path(file))

    if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
        raise gr.Error("Image dimensions must be divisible by 8 for wavelet processing")

    img_norm = _normalize(img)
    try:
        sharp = _enhance(img_norm)
    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}") from e

    return _display_original(img), sharp, _histogram_image(sharp)


with gr.Blocks(title="MUSICA Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ MUSICA X-ray Image Enhancement")
    gr.Markdown("Upload a 16-bit grayscale TIFF for wavelet-based enhancement")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Input TIFF",
                file_types=[".tif", ".tiff"],  # Gradio expects dotted extensions
                height=100,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            original_output = gr.Image(
                label="Original (Clipped)", height=400, type="numpy"
            )

    with gr.Row():
        enhanced_output = gr.Image(label="Enhanced Result", type="numpy", height=400)
        hist_output = gr.Image(label="Histogram", type="pil", height=400)

    outputs = [original_output, enhanced_output, hist_output]
    submit_btn.click(process_tiff, inputs=file_input, outputs=outputs)

    # Only offer the example when the file ships alongside the app; a missing
    # example file would otherwise break app construction.
    if SAMPLE_PATH.exists():
        gr.Examples(
            examples=[["sample.tif"]],
            inputs=file_input,
            outputs=outputs,
            fn=process_tiff,
        )

if __name__ == "__main__":
    demo.launch()