# Xray_enhance / app.py
# Last commit by ahmedxeno: "Update app.py" (16bd718, verified)
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pywt
from matplotlib.figure import Figure
from PIL import Image
from skimage import exposure
# Optional: report the installed Gradio release at startup so the runtime
# version shows up in the logs.
_gradio_version = gr.__version__
print("Gradio version:", _gradio_version)
def _to_unit_grayscale(image):
    """Convert an uploaded image to a 2-D float32 grayscale array in [0, 1].

    Args:
        image: A PIL image in any mode, or anything ``np.array`` accepts.

    Returns:
        A 2-D ``np.float32`` array. 8-bit data is first stretched onto the
        16-bit range so every input ends up on the same [0, 1] scale.

    Raises:
        gr.Error: If no image was supplied, or the array is not 2-D after
            channel reduction.
    """
    if image is None:
        raise gr.Error("Please upload an image before running the enhancement.")
    if isinstance(image, Image.Image):
        mode = image.mode
        # Grayscale modes ("L", "F", and the integer "I" / "I;16*" modes) and
        # RGB/RGBA map directly onto arrays handled below; anything else
        # (palette, bilevel, LA, CMYK, ...) is converted to RGB first.
        if not (mode in ("L", "F", "RGB", "RGBA") or mode.startswith("I")):
            image = image.convert("RGB")

    img = np.array(image)
    # Debugging: Print image properties
    print(f"Uploaded image shape: {img.shape}, dtype: {img.dtype}")
    print(f"Image min: {img.min()}, max: {img.max()}")

    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            print("Converted RGB image to grayscale.")
        elif channels == 4:
            # Alpha is discarded; only the colour channels contribute.
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            print("Converted RGBA image to grayscale.")
        else:
            # 1- or 2-channel data (intensity, optionally + alpha).
            img = img[..., 0]
    if img.ndim != 2:
        raise gr.Error(
            f"Unsupported image shape {img.shape}; "
            "expected a grayscale or RGB(A) image."
        )

    if img.dtype == np.uint8:
        # 255 * 257 == 65535: maps 8-bit data exactly onto the 16-bit range
        # instead of leaving it in the bottom 1/257th of [0, 1].
        img = img.astype(np.uint16) * 257
        print("Scaled 8-bit image to 16-bit.")
    elif img.dtype != np.uint16:
        # Clip first so negative or >16-bit values saturate instead of
        # wrapping around during the cast.
        img = np.clip(img, 0, 65535).astype(np.uint16)
        print("Converted image to 16-bit.")

    img_norm = img.astype(np.float32) / 65535.0
    print(f"Normalized image min: {img_norm.min()}, max: {img_norm.max()}")
    return img_norm


def _wavelet_enhance(img_norm):
    """Denoise and sharpen a [0, 1] grayscale image in the wavelet domain.

    Performs a 3-level 'bior1.3' decomposition, then:
    - soft-thresholds the level-1 and level-2 diagonal detail bands at 5% and
      7% of their peak magnitude;
    - amplifies the level-1 horizontal and vertical detail bands by 1.2;
    - reconstructs the image and clips it back to [0, 1].

    Returns:
        An array with the same shape as ``img_norm``.
    """
    coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs
    # Threshold relative to the peak *magnitude*. A signed max can be negative,
    # which would make soft thresholding amplify the coefficients.
    cD1 = pywt.threshold(cD1, 0.05 * np.max(np.abs(cD1)), mode='soft')
    cD2 = pywt.threshold(cD2, 0.07 * np.max(np.abs(cD2)), mode='soft')
    cH1 = cH1 * 1.2
    cV1 = cV1 * 1.2
    coeffs_enhanced = [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)]
    rows, cols = img_norm.shape
    # waverec2 yields one extra row/column for odd-sized inputs; crop so the
    # output keeps the uploaded image's dimensions.
    img_recon = pywt.waverec2(coeffs_enhanced, 'bior1.3')[:rows, :cols]
    img_recon = np.clip(img_recon, 0, 1)
    print(f"Reconstructed image min: {img_recon.min()}, max: {img_recon.max()}")
    return img_recon


def _histogram_image(pixels):
    """Render a 256-bin histogram of 8-bit ``pixels`` as a PNG-backed PIL image."""
    # An explicit Figure (not pyplot) keeps no global state, so concurrent
    # requests handled in Gradio worker threads cannot draw on or close each
    # other's figures.
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.hist(pixels.ravel(), bins=256, range=(0, 255), color='gray')
    ax.set_title('Enhanced Histogram')
    ax.set_xlabel('Pixel Intensity')
    ax.set_ylabel('Frequency')
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def musica_enhancement(image):
    """Enhance an uploaded image and return it with its intensity histogram.

    Pipeline:
    1. Grayscale conversion and normalization to [0, 1].
    2. Wavelet-domain denoising and sharpening.
    3. CLAHE (``exposure.equalize_adapthist``, clip_limit=0.02,
       kernel_size=64).
    4. Gamma correction with gamma=0.9.
    5. Conversion to 8-bit.

    Args:
        image: The PIL image from the Gradio input (any mode), or an
            array-like.

    Returns:
        A tuple ``(enhanced_image, histogram)``. ``enhanced_image`` is an RGB
        PIL image with the input's width and height. ``histogram`` is a PIL
        image of the enhanced intensity histogram.

    Raises:
        gr.Error: If no image was supplied or its shape is unsupported.
    """
    img_norm = _to_unit_grayscale(image)
    img_recon = _wavelet_enhance(img_norm)

    # CLAHE
    img_clahe = exposure.equalize_adapthist(img_recon, clip_limit=0.02, kernel_size=64)
    # Gamma correction
    img_gamma = exposure.adjust_gamma(img_clahe, gamma=0.9)
    # Convert to 8-bit
    img_gamma_8bit = (img_gamma * 255).astype(np.uint8)
    # Convert to RGB for better viewing compatibility
    img_rgb = cv2.cvtColor(img_gamma_8bit, cv2.COLOR_GRAY2RGB)
    enhanced_image = Image.fromarray(img_rgb)

    histogram = _histogram_image(img_gamma_8bit)
    return enhanced_image, histogram


# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Musica Image Enhancement")
    gr.Markdown(
        """
Upload a 16-bit TIFF image to enhance it using wavelet decomposition, CLAHE,
gamma correction, and edge-preserving sharpening.
"""
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                # Keep the file's native mode (e.g. "I;16" for 16-bit TIFFs).
                # The default "RGB" would reduce every upload to 8 bits before
                # musica_enhancement sees it.
                image_mode=None,
                label="Upload 16-bit TIFF Image",
            )
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Image")
            output_hist = gr.Image(type="pil", label="Enhanced Histogram")
    run_button.click(
        fn=musica_enhancement,
        inputs=input_image,
        outputs=[output_image, output_hist],
    )
    gr.Examples(
        examples=[
            "examples/sample1.tif",
            "examples/sample2.tif",
            "examples/sample3.tif"
        ],
        inputs=input_image,
        label="Or try one of these examples",
    )

if __name__ == "__main__":
    demo.launch()