# Spaces status: Running (page-capture header; not part of the program)
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pywt | |
from skimage import exposure | |
import gradio as gr | |
from io import BytesIO | |
from PIL import Image | |
# Optional diagnostic: print the installed Gradio version when the module is
# loaded, so the console output shows which Gradio release the UI code runs on.
print("Gradio version:", gr.__version__)
def _to_unit_float(image):
    """Convert a PIL image to a 2-D float32 array scaled to [0, 1].

    Raises:
        ValueError: If the image mode is not one of the supported modes.
    """
    mode = image.mode
    pixels = np.array(image)

    if mode.startswith('I;16'):
        # 'I;16' and its byte-order variants ('I;16L', 'I;16B', 'I;16N') are
        # all unsigned 16-bit, so they share the same white point.
        return pixels.astype(np.float32) / 65535.0

    if mode == 'I':
        # Mode 'I' is a *signed* 32-bit container with no canonical white
        # point. Dividing by 2**32 - 1 would squash real data into a sliver
        # near zero and would leave negative values negative.
        data = pixels.astype(np.float64)
        low, high = float(data.min()), float(data.max())
        if low >= 0.0 and high <= 65535.0:
            # NOTE(review): assumes data in this range is 16-bit content that
            # was decoded into a 32-bit container - confirm for your inputs.
            return (data / 65535.0).astype(np.float32)
        span = high - low
        if span == 0.0:
            # Constant image: avoid a division by zero.
            return np.zeros(data.shape, dtype=np.float32)
        return ((data - low) / span).astype(np.float32)

    if mode in ('RGB', 'RGBA'):
        # Collapse colour images to a single luminance channel.
        code = cv2.COLOR_RGBA2GRAY if mode == 'RGBA' else cv2.COLOR_RGB2GRAY
        gray = cv2.cvtColor(pixels, code)
        return gray.astype(np.float32) / 255.0

    if mode == 'L':
        # 8-bit grayscale.
        return pixels.astype(np.float32) / 255.0

    raise ValueError(f"Unsupported image mode: {image.mode}")


def _wavelet_enhance(img_norm):
    """Denoise/boost wavelet sub-bands and return the reconstruction in [0, 1]."""
    rows, cols = img_norm.shape[:2]

    # 3-level 2-D decomposition: [cA3, (cH3, cV3, cD3), (cH2, ...), (cH1, ...)].
    approx, level3, level2, level1 = pywt.wavedec2(img_norm, 'bior1.3', level=3)
    h2, v2, d2 = level2
    h1, v1, d1 = level1

    # Detail coefficients are signed, so the threshold must be relative to the
    # largest *magnitude*; np.max alone can be tiny or even negative.
    d1 = pywt.threshold(d1, 0.05 * np.max(np.abs(d1)), mode='soft')
    d2 = pywt.threshold(d2, 0.07 * np.max(np.abs(d2)), mode='soft')

    # Mild boost of the finest horizontal/vertical detail.
    h1 = h1 * 1.2
    v1 = v1 * 1.2

    recon = pywt.waverec2(
        [approx, level3, (h2, v2, d2), (h1, v1, d1)], 'bior1.3'
    )
    # The inverse transform can come back one pixel larger than an odd-sized
    # input; crop so the output matches the uploaded image's dimensions.
    recon = recon[:rows, :cols]
    # Boosted coefficients can overshoot the valid range; later steps
    # (log, CLAHE, gamma) require values in [0, 1].
    return np.clip(recon, 0, 1)


def _histogram_entropy(img_unit, bins=256):
    """Return the Shannon entropy (bits) of the intensity histogram of a [0, 1] image."""
    counts, _ = np.histogram(img_unit, bins=bins, range=(0.0, 1.0))
    total = counts.sum()
    if total == 0:
        return 0.0
    prob = counts[counts > 0] / total
    return float(-np.sum(prob * np.log2(prob)))


def _render_histogram(img_8bit):
    """Plot the intensity histogram of an 8-bit image and return it as a PIL image."""
    # Use an explicit figure instead of the implicit pyplot "current figure",
    # and always close it so a failed plot does not leak figures.
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.hist(img_8bit.ravel(), bins=256, range=(0, 255), color='gray')
        ax.set_title('Enhanced Histogram')
        ax.set_xlabel('Pixel Intensity')
        ax.set_ylabel('Frequency')
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    histogram = Image.open(buf)
    # Image.open is lazy; decode now so the result is fully materialised.
    histogram.load()
    return histogram


def musica_enhancement(image):
    """
    Enhance an image using wavelet decomposition, CLAHE, gamma correction,
    and edge-preserving sharpening.

    Supported PIL modes: 'I;16' (and byte-order variants), 'I', 'RGB', 'RGBA'
    and 'L'. Colour inputs are converted to grayscale first.

    Args:
        image (PIL.Image): Uploaded image.

    Returns:
        enhanced_image (PIL.Image): Enhanced 8-bit grayscale image.
        histogram (PIL.Image): Histogram plot of the enhanced image.

    Raises:
        ValueError: If no image is supplied or the image mode is unsupported.
    """
    if image is None:
        # The UI passes None when the button is pressed without an upload.
        raise ValueError("No image provided. Please upload an image first.")

    # Convert to a normalised float image in [0, 1].
    img_norm = _to_unit_float(image)

    # 1-3. Multi-scale decomposition, per-sub-band processing, reconstruction.
    img_recon = _wavelet_enhance(img_norm)

    # 4. Adaptive CLAHE: high-entropy (already detailed) images get a gentler
    #    clip limit. Entropy is measured in bits over a 256-bin histogram, so
    #    it lies in [0, 8] and the threshold of 7 is meaningful.
    entropy = _histogram_entropy(img_recon)
    clip_limit = 0.02 if entropy > 7 else 0.05
    img_clahe = exposure.equalize_adapthist(
        img_recon, clip_limit=clip_limit, kernel_size=64
    )

    # 5. Gamma correction: brighten more strongly when the 5th-95th percentile
    #    spread is narrow (low contrast).
    p5, p95 = np.percentile(img_clahe, (5, 95))
    gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
    img_gamma = exposure.adjust_gamma(img_clahe, gamma=gamma)

    # 6. Edge-preserving sharpening. The image is expanded to 3-channel BGR
    #    before cv2.detailEnhance and converted back to grayscale afterwards.
    img_gamma_8bit = (img_gamma * 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_gamma_8bit, cv2.COLOR_GRAY2BGR)
    img_sharp = cv2.detailEnhance(img_bgr, sigma_s=12, sigma_r=0.15)
    img_sharp = cv2.cvtColor(img_sharp, cv2.COLOR_BGR2GRAY)

    # Convert the result and its histogram plot to PIL images.
    enhanced_image = Image.fromarray(img_sharp)
    histogram = _render_histogram(img_sharp)
    return enhanced_image, histogram
# Define the Gradio interface: an upload panel with a trigger button on the
# left, and the enhanced image plus its histogram on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Musica Image Enhancement")
    gr.Markdown(
        """
        Upload a 16-bit TIFF image to enhance it using wavelet decomposition, CLAHE,
        gamma correction, and edge-preserving sharpening.
        """
    )
    with gr.Row():
        with gr.Column():
            # No 'source' or 'tool' arguments: both were dropped from
            # gr.Image in Gradio 4, and passing them raises a TypeError there.
            # On Gradio 3 the defaults already give an upload box with the
            # editor, so omitting them keeps the same behaviour.
            input_image = gr.Image(
                type="pil",
                label="Upload 16-bit TIFF Image",
            )
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Enhanced Image")
            output_hist = gr.Image(type="pil", label="Enhanced Histogram")
    # Run the enhancement pipeline on click and fill both output panels.
    run_button.click(
        fn=musica_enhancement,
        inputs=input_image,
        outputs=[output_image, output_hist],
    )
    # Clickable sample inputs; the paths are resolved relative to the app's
    # working directory.
    gr.Examples(
        examples=[
            "sample1.tif",
            "sample2.tif",
            "sample3.tif",
        ],
        inputs=input_image,
        label="Or try one of these examples",
    )

if __name__ == "__main__":
    demo.launch()