# Spaces: Sleeping
import cv2 | |
import numpy as np | |
import pywt | |
from skimage import exposure | |
import gradio as gr | |
from PIL import Image | |
from io import BytesIO | |
import matplotlib.pyplot as plt | |
def _to_unit_range(img):
    """Scale raw pixel data to float32 in [0, 1] according to its dtype."""
    if np.issubdtype(img.dtype, np.integer):
        # Divide by the dtype's full-scale value (65535 for uint16, 255 for
        # uint8) instead of assuming every input is 16-bit.
        return img.astype(np.float32) / float(np.iinfo(img.dtype).max)
    data = img.astype(np.float32)
    peak = float(np.max(data)) if data.size else 0.0
    if peak <= 0:
        return np.zeros_like(data)
    return np.clip(data / peak, 0, 1)


def _soft_threshold(coeffs, fraction):
    """Soft-threshold *coeffs* at ``fraction`` of their largest magnitude."""
    limit = fraction * float(np.max(np.abs(coeffs)))
    if limit <= 0:
        # All-zero band: nothing to shrink. Skipping the call also avoids
        # evaluating 0/0 inside the soft-threshold formula (would yield NaN).
        return coeffs
    return pywt.threshold(coeffs, limit, 'soft')


def _histogram_entropy(unit_img, bins=256):
    """Return the Shannon entropy (bits) of the intensity histogram of a [0, 1] image."""
    counts, _ = np.histogram(unit_img, bins=bins, range=(0.0, 1.0))
    total = counts.sum()
    if total == 0:
        return 0.0
    prob = counts[counts > 0] / total
    return float(-np.sum(prob * np.log2(prob)))


def _render_histogram(pixels):
    """Plot a 256-bin histogram of 8-bit *pixels* and return it as a PIL image."""
    fig, ax = plt.subplots()
    try:
        ax.hist(pixels.ravel(), bins=256, range=(0, 255))
        ax.set_title("Enhanced Histogram")
        ax.set_xlabel("Pixel Value")
        ax.set_ylabel("Frequency")
        buf = BytesIO()
        # Save through the figure object rather than pyplot's global
        # "current figure", which another request may have replaced.
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    buf.seek(0)
    hist_img = Image.open(buf)
    hist_img.load()  # decode now so the result does not depend on the buffer
    return hist_img


def process_tiff(file):
    """Enhance a grayscale TIFF with wavelet shaping, CLAHE, gamma and sharpening.

    Args:
        file: The uploaded file, either a path string or an object whose
            ``name`` attribute holds the path.

    Returns:
        A tuple ``(original_display, sharp, hist_img)``: the input scaled to
        8-bit for display, the enhanced 8-bit image, and a PIL image of the
        enhanced image's histogram.

    Raises:
        gr.Error: If the file cannot be read, its dimensions are not
            divisible by 8, or any processing step fails.
    """
    # Read file content
    try:
        if file is None:
            raise ValueError("No file was uploaded")
        path = file if isinstance(file, str) else file.name
        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError("Invalid or corrupted TIFF file")
    except Exception as e:
        raise gr.Error(f"Error reading file: {str(e)}") from e

    # Normalize to [0, 1]
    img_norm = _to_unit_range(img)

    # Check dimensions for wavelet transform
    if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
        raise gr.Error("Image dimensions must be divisible by 8 for wavelet processing")

    try:
        # Wavelet decomposition
        coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
        cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs

        # Processing coefficients: shrink the diagonal detail bands and
        # boost the finest horizontal/vertical detail.
        cD1 = _soft_threshold(cD1, 0.05)
        cD2 = _soft_threshold(cD2, 0.07)
        cH1 = cH1 * 1.2
        cV1 = cV1 * 1.2

        # Reconstruction
        recon = pywt.waverec2(
            [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)],
            'bior1.3',
        )
        # Guard against the reconstruction being padded beyond the input size.
        recon = recon[:img.shape[0], :img.shape[1]]
        recon = np.clip(recon, 0, 1)

        # CLAHE: the threshold of 7 is in bits, out of a maximum of 8 for a
        # 256-bin histogram, so the entropy must be a histogram entropy.
        entropy = _histogram_entropy(recon)
        clahe_img = exposure.equalize_adapthist(
            recon,
            clip_limit=0.02 if entropy > 7 else 0.05,
            kernel_size=64,
        )

        # Gamma correction: brighten more when the 5th-95th percentile
        # spread is narrow.
        p5, p95 = np.percentile(clahe_img, (5, 95))
        gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
        gamma_img = exposure.adjust_gamma(clahe_img, gamma)

        # Sharpening (detailEnhance is called on a 3-channel 8-bit image)
        sharp = cv2.detailEnhance(
            cv2.cvtColor((gamma_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
            sigma_s=12,
            sigma_r=0.15
        )
        sharp = cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)
    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}") from e

    # Prepare outputs: scale by the 99.5th percentile so a few hot pixels do
    # not darken the preview; fall back to the maximum for mostly-dark images.
    reference = float(np.percentile(img, 99.5))
    if reference <= 0:
        reference = float(np.max(img))
    if reference > 0:
        original_display = (np.clip(img / reference, 0, 1) * 255).astype(np.uint8)
    else:
        original_display = np.zeros(img.shape, dtype=np.uint8)

    # Create histogram plot and convert it to a PIL Image
    hist_img = _render_histogram(sharp)

    return original_display, sharp, hist_img
# Gradio UI: upload a TIFF, run process_tiff, and show the clipped original,
# the enhanced result and its histogram.
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been stripped — confirm it matches the intended layout.
with gr.Blocks(title="MUSICA Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ MUSICA X-ray Image Enhancement")
    gr.Markdown("Upload a 16-bit grayscale TIFF for wavelet-based enhancement")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Input TIFF",
                # Extensions need a leading dot; bare words are treated as
                # file-type categories such as "image" or "audio".
                file_types=[".tif", ".tiff"],
                height=100
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            original_output = gr.Image(
                label="Original (Clipped)",
                height=400,
                type="numpy"
            )

    with gr.Row():
        enhanced_output = gr.Image(
            label="Enhanced Result",
            type="numpy",
            height=400
        )
        hist_output = gr.Image(
            label="Histogram",
            type="pil",
            height=400
        )

    # Run the enhancement pipeline when the button is pressed.
    submit_btn.click(
        process_tiff,
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output]
    )

    # NOTE(review): assumes sample.tif is shipped next to this script — confirm.
    gr.Examples(
        examples=[["sample.tif"]],
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output],
        fn=process_tiff
    )

if __name__ == "__main__":
    demo.launch()