# ahmedxeno's picture
# Update app.py
# 9700750 verified
import cv2
import numpy as np
import pywt
from skimage import exposure
import gradio as gr
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
def _read_grayscale_tiff(file):
    """Load the uploaded file as a single-channel array at its native bit depth.

    Args:
        file: A path string (Gradio 4+ ``gr.File`` default and ``gr.Examples``)
            or an object exposing the path as ``.name`` (Gradio 3.x tempfile).

    Raises:
        gr.Error: if no file was given or OpenCV cannot decode it.
    """
    if file is None:
        raise gr.Error("Please upload a TIFF file")
    try:
        path = file if isinstance(file, str) else file.name
        # IMREAD_ANYDEPTH keeps 16-bit samples instead of converting to 8-bit.
        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
    except Exception as e:
        raise gr.Error(f"Error reading file: {e}") from e
    if img is None:
        raise gr.Error("Error reading file: Invalid or corrupted TIFF file")
    return img


def _normalize(img):
    """Convert raw pixel values to float32 scaled by the input's full-scale value."""
    # uint8 data uses 255; everything else keeps the 16-bit assumption the UI
    # advertises ("Upload a 16-bit grayscale TIFF").
    full_scale = 255.0 if img.dtype == np.uint8 else 65535.0
    return img.astype(np.float32) / full_scale


def _histogram_entropy(image, bins=256):
    """Return the Shannon entropy, in bits, of the intensity histogram of *image*.

    *image* is expected to lie in [0, 1]; with 256 bins the result is in [0, 8].
    """
    counts, _ = np.histogram(image, bins=bins, range=(0.0, 1.0))
    total = counts.sum()
    if total == 0:
        return 0.0
    p = counts[counts > 0] / total
    return float(-np.sum(p * np.log2(p)))


def _soft_threshold_relative(coeffs, fraction):
    """Soft-threshold *coeffs* at *fraction* of their peak absolute value."""
    peak = float(np.max(np.abs(coeffs)))
    if peak == 0.0:
        # All-zero band: nothing to shrink, and a zero threshold would make
        # pywt's soft threshold compute 0/0 (NaN) for zero coefficients.
        return coeffs
    return pywt.threshold(coeffs, fraction * peak, 'soft')


def _enhance(img_norm):
    """Run wavelet denoise/boost, CLAHE, gamma and detail enhancement.

    Args:
        img_norm: 2-D float image, expected in [0, 1].

    Returns:
        The enhanced image as a 2-D uint8 array.
    """
    # Three-level biorthogonal wavelet decomposition.
    coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs
    # Shrink the two finest diagonal detail bands relative to their peak
    # magnitude (the signed max could be <= 0 and invert the operation).
    cD1 = _soft_threshold_relative(cD1, 0.05)
    cD2 = _soft_threshold_relative(cD2, 0.07)
    # Amplify the finest horizontal/vertical detail bands.
    cH1 *= 1.2
    cV1 *= 1.2
    recon = pywt.waverec2(
        [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)], 'bior1.3'
    )
    recon = np.clip(recon, 0, 1)

    # CLAHE: use the lower clip limit (less contrast amplification) when the
    # intensity histogram entropy exceeds 7 of the possible 8 bits.
    clip_limit = 0.02 if _histogram_entropy(recon) > 7 else 0.05
    clahe_img = exposure.equalize_adapthist(recon, clip_limit=clip_limit, kernel_size=64)

    # Gamma < 1 brightens; use the stronger 0.7 when the 5-95 percentile
    # spread is narrow.
    p5, p95 = np.percentile(clahe_img, (5, 95))
    gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
    gamma_img = exposure.adjust_gamma(clahe_img, gamma)

    # detailEnhance operates on 3-channel 8-bit images, so round-trip via BGR.
    sharp = cv2.detailEnhance(
        cv2.cvtColor((gamma_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
        sigma_s=12,
        sigma_r=0.15,
    )
    return cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)


def _display_original(img):
    """Stretch the raw image to uint8, saturating values above the 99.5th percentile."""
    white = float(np.percentile(img, 99.5))
    if white <= 0:
        # Mostly-black image: fall back to the maximum to avoid dividing by 0.
        white = float(np.max(img))
    if white <= 0:
        white = 1.0
    return (np.clip(img / white, 0, 1) * 255).astype(np.uint8)


def _histogram_image(sharp):
    """Render a 256-bin histogram of the uint8 image *sharp* as a PIL image."""
    fig, ax = plt.subplots()
    try:
        ax.hist(sharp.ravel(), bins=256, range=(0, 255))
        ax.set_title("Enhanced Histogram")
        ax.set_xlabel("Pixel Value")
        ax.set_ylabel("Frequency")
        buf = BytesIO()
        # Save this figure explicitly: plt.savefig targets pyplot's global
        # "current" figure, which a concurrent request could have changed.
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def process_tiff(file):
    """Enhance an uploaded grayscale TIFF and build the three Gradio outputs.

    Args:
        file: The ``gr.File`` value, either a path string or an object whose
            ``.name`` attribute is the path.

    Returns:
        ``(original_display, enhanced, histogram)``. The first is the input
        stretched to uint8, the second the enhanced uint8 image, and the third
        a PIL image of the enhanced image's histogram.

    Raises:
        gr.Error: if no file is given, it cannot be decoded, its dimensions
            are not multiples of 8, or the enhancement pipeline fails.
    """
    img = _read_grayscale_tiff(file)
    # Three decomposition levels; the pipeline requires both sides to be
    # multiples of 2**3.
    if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
        raise gr.Error("Image dimensions must be divisible by 8 for wavelet processing")
    img_norm = _normalize(img)
    try:
        sharp = _enhance(img_norm)
    except Exception as e:
        raise gr.Error(f"Processing error: {e}") from e
    return _display_original(img), sharp, _histogram_image(sharp)
# Gradio UI layout:
# - Row 1: the upload + button column beside the clipped original preview.
# - Row 2: the enhanced result beside its histogram.
with gr.Blocks(title="MUSICA Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ MUSICA X-ray Image Enhancement")
    gr.Markdown("Upload a 16-bit grayscale TIFF for wavelet-based enhancement")

    with gr.Row():
        with gr.Column():
            # Gradio expects file extensions with a leading dot (".tif");
            # undotted entries are interpreted as file-type categories.
            file_input = gr.File(
                label="Input TIFF",
                file_types=[".tif", ".tiff"],
                height=100,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            original_output = gr.Image(
                label="Original (Clipped)",
                height=400,
                type="numpy",
            )

    with gr.Row():
        enhanced_output = gr.Image(
            label="Enhanced Result",
            type="numpy",
            height=400,
        )
        hist_output = gr.Image(
            label="Histogram",
            type="pil",
            height=400,
        )

    # process_tiff returns (original, enhanced, histogram) in this order.
    submit_btn.click(
        process_tiff,
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output],
    )

    # NOTE(review): assumes sample.tif ships next to this script; if Gradio
    # caches examples at startup, a missing file would fail at launch.
    gr.Examples(
        examples=[["sample.tif"]],
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output],
        fn=process_tiff,
    )

if __name__ == "__main__":
    demo.launch()