# Xray_enhance / app.py
# Last commit f1445e4 ("Update app.py", verified) by ahmedxeno; file size 4.45 kB.
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pywt
from skimage import exposure
import gradio as gr
from io import BytesIO
from PIL import Image
# Startup diagnostic: report which Gradio release is installed. This is useful
# when checking which gr.Image keyword arguments the installed version accepts.
_gradio_version = gr.__version__
print("Gradio version:", _gradio_version)
def _to_unit_gray(img, mode):
    """Convert Pillow pixel data to a single-channel float32 image in [0, 1].

    Args:
        img (np.ndarray): Pixel data obtained with ``np.array(pil_image)``.
        mode (str): Pillow mode of the source image.

    Returns:
        np.ndarray: float32 array scaled to [0, 1].

    Raises:
        ValueError: If ``mode`` is not supported.
    """
    if mode.startswith('I;16'):
        # 16-bit unsigned ('I;16', 'I;16B', 'I;16L', ...): full scale is 65535.
        return img.astype(np.float32) / 65535.0
    if mode == 'I':
        # 32-bit *signed* integers. Pillow may also use this mode for 16-bit
        # grayscale files, so dividing by 2**32 - 1 would leave the image
        # almost black. Use the 16-bit scale when the data fit it; otherwise
        # stretch the observed value range onto [0, 1].
        data = img.astype(np.int64)
        lo, hi = int(data.min()), int(data.max())
        if lo >= 0 and hi <= 65535:
            return data.astype(np.float32) / 65535.0
        if hi == lo:
            return np.zeros(data.shape, dtype=np.float32)
        return ((data - lo) / float(hi - lo)).astype(np.float32)
    if mode in ('RGB', 'RGBA'):
        # Pillow stores channels in RGB(A) order, hence COLOR_RGB2GRAY.
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return gray.astype(np.float32) / 255.0
    if mode == 'L':
        # 8-bit grayscale.
        return img.astype(np.float32) / 255.0
    raise ValueError(f"Unsupported image mode: {mode}")


def _wavelet_enhance(img_norm, wavelet='bior1.3'):
    """Denoise and edge-boost a [0, 1] image in the wavelet domain.

    Performs a 3-level 2-D decomposition, soft-thresholds the two finest
    diagonal detail bands, and scales the finest horizontal/vertical detail
    by 1.2 before reconstructing.

    Returns:
        np.ndarray: Reconstructed image with the same shape as ``img_norm``,
        clipped to [0, 1].
    """
    coeffs = pywt.wavedec2(img_norm, wavelet, level=3)
    c_a3, details3, (c_h2, c_v2, c_d2), (c_h1, c_v1, c_d1) = coeffs

    # Threshold relative to the peak coefficient *magnitude*. Detail
    # coefficients are signed, and a plain max() can be tiny or negative,
    # which would make soft thresholding amplify instead of shrink.
    c_d1 = pywt.threshold(c_d1, 0.05 * np.max(np.abs(c_d1)), mode='soft')
    c_d2 = pywt.threshold(c_d2, 0.07 * np.max(np.abs(c_d2)), mode='soft')
    c_h1 = c_h1 * 1.2
    c_v1 = c_v1 * 1.2

    recon = pywt.waverec2(
        [c_a3, details3, (c_h2, c_v2, c_d2), (c_h1, c_v1, c_d1)], wavelet
    )
    # waverec2 may return one extra row/column for odd-sized inputs; crop so
    # the output matches the input size.
    recon = recon[:img_norm.shape[0], :img_norm.shape[1]]
    return np.clip(recon, 0, 1)


def _histogram_entropy(img, bins=256):
    """Return the Shannon entropy, in bits, of the intensity histogram of a [0, 1] image.

    With the default 256 bins the result lies in [0, 8].
    """
    counts, _ = np.histogram(img, bins=bins, range=(0.0, 1.0))
    total = counts.sum()
    if total == 0:
        return 0.0
    p = counts[counts > 0] / total
    return float(-np.sum(p * np.log2(p)))


def _render_histogram(img_u8):
    """Render the intensity histogram of an 8-bit image as a PNG-backed PIL image."""
    # Use a standalone Figure instead of pyplot's global "current figure".
    # Gradio may run requests concurrently, and nothing needs to be closed.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.hist(img_u8.ravel(), bins=256, range=(0, 255), color='gray')
    ax.set_title('Enhanced Histogram')
    ax.set_xlabel('Pixel Intensity')
    ax.set_ylabel('Frequency')
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def musica_enhancement(image):
    """Enhance a radiograph with wavelet denoising, CLAHE, gamma and sharpening.

    Steps:
        1. Convert to a single-channel float image in [0, 1].
        2. Apply a 3-level ``bior1.3`` wavelet decomposition: soft-threshold
           the diagonal detail and boost the fine horizontal/vertical detail.
        3. Apply CLAHE, with a clip limit chosen from the histogram entropy.
        4. Apply gamma correction, with gamma chosen from the 5-95 percentile
           spread.
        5. Apply OpenCV's edge-preserving ``detailEnhance``.

    Args:
        image (PIL.Image.Image): Uploaded image. Supported modes are 'I;16'
            and its variants (16-bit), 'I' (32-bit integer), 'RGB'/'RGBA'
            (converted to grayscale) and 'L'.

    Returns:
        tuple[PIL.Image.Image, PIL.Image.Image]: The enhanced 8-bit grayscale
        image, with the same height and width as the input, and a rendering
        of its intensity histogram.

    Raises:
        ValueError: If ``image`` is None or its mode is unsupported.
    """
    if image is None:
        raise ValueError("No image provided.")
    img_norm = _to_unit_gray(np.array(image), image.mode)

    # Wavelet-domain denoising and edge boost, reconstructed to [0, 1].
    img_recon = _wavelet_enhance(img_norm)

    # Adaptive CLAHE: a gentler clip limit for detail-rich (high-entropy) images.
    clip_limit = 0.02 if _histogram_entropy(img_recon) > 7 else 0.05
    img_clahe = exposure.equalize_adapthist(
        img_recon, clip_limit=clip_limit, kernel_size=64
    )

    # Gamma < 1 brightens; brighten more when the contrast spread is narrow.
    p5, p95 = np.percentile(img_clahe, (5, 95))
    gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
    img_gamma = exposure.adjust_gamma(img_clahe, gamma=gamma)

    # Edge-preserving sharpening. cv2.detailEnhance expects 8-bit 3-channel
    # input, so round-trip through BGR.
    img_gamma_8bit = (img_gamma * 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_gamma_8bit, cv2.COLOR_GRAY2BGR)
    img_sharp = cv2.detailEnhance(img_bgr, sigma_s=12, sigma_r=0.15)
    img_sharp = cv2.cvtColor(img_sharp, cv2.COLOR_BGR2GRAY)

    return Image.fromarray(img_sharp), _render_histogram(img_sharp)
# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Musica Image Enhancement")
    gr.Markdown(
"""
Upload a 16-bit TIFF image to enhance it using wavelet decomposition, CLAHE,
gamma correction, and edge-preserving sharpening.
"""
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                # Keep the uploaded file's native mode. The default ("RGB")
                # converts 16-bit data to 8-bit RGB before musica_enhancement
                # ever sees it, which defeats its 16-bit handling.
                image_mode=None,
                label="Upload 16-bit TIFF Image",
                # NOTE: `tool` is a Gradio 3.x argument; it was removed in 4.0.
                tool="editor",
            )
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Image")
            output_hist = gr.Image(type="pil", label="Enhanced Histogram")

    run_button.click(
        fn=musica_enhancement,
        inputs=input_image,
        outputs=[output_image, output_hist],
    )

    gr.Examples(
        examples=[
            "sample1.tif",
            "sample2.tif",
            "sample3.tif",
        ],
        inputs=input_image,
        label="Or try one of these examples",
    )

if __name__ == "__main__":
    demo.launch()