File size: 3,851 Bytes
4ae08c7
 
ac2c4c0
 
f850371
0e7f8ec
f850371
9700750
ac2c4c0
f850371
9700750
 
 
 
 
 
 
 
ac2c4c0
f850371
ac2c4c0
9700750
 
 
 
f850371
9700750
f850371
9700750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae08c7
f850371
9700750
 
 
 
 
 
 
 
0e7f8ec
9700750
f850371
 
9700750
f850371
4ae08c7
f850371
 
9700750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import cv2
import numpy as np
import pywt
from skimage import exposure
import gradio as gr
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt

def _to_unit_range(img):
    """Return ``img`` as float32 scaled to [0, 1] according to its dtype."""
    if np.issubdtype(img.dtype, np.integer):
        # Divide by the dtype's full-scale value (65535 for uint16, 255 for
        # uint8) rather than assuming every input is 16-bit.
        return img.astype(np.float32) / float(np.iinfo(img.dtype).max)
    # Floating-point images have no fixed full scale; use the observed range.
    data = img.astype(np.float32)
    lo = float(data.min())
    hi = float(data.max())
    if not hi > lo:
        return np.zeros_like(data)
    return (data - lo) / (hi - lo)


def _soft_threshold(band, fraction):
    """Soft-threshold ``band`` at ``fraction`` of its largest magnitude."""
    limit = fraction * float(np.max(np.abs(band)))
    if limit <= 0:
        # Nothing to shrink when the band is all zeros; skip the call.
        return band
    return pywt.threshold(band, limit, 'soft')


def _histogram_entropy(unit_img, bins=256):
    """Return the Shannon entropy (bits) of the intensity histogram of ``unit_img``."""
    counts, _ = np.histogram(unit_img, bins=bins, range=(0.0, 1.0))
    total = counts.sum()
    if total == 0:
        return 0.0
    probs = counts[counts > 0] / total
    return float(-np.sum(probs * np.log2(probs)))


def _clip_for_display(img):
    """Scale ``img`` by its 99.5th percentile and return it as uint8."""
    scale = float(np.percentile(img, 99.5))
    if not scale > 0:
        # An all-black image would otherwise divide by zero.
        return np.zeros(img.shape, dtype=np.uint8)
    return (np.clip(img / scale, 0, 1) * 255).astype(np.uint8)


def process_tiff(file):
    """Enhance a grayscale TIFF and return display images for the UI.

    Pipeline: 3-level 'bior1.3' wavelet decomposition, soft-thresholding of
    the diagonal detail bands at levels 1 and 2, a 1.2x boost of the level-1
    horizontal/vertical bands, reconstruction, CLAHE, gamma correction and
    OpenCV detail enhancement.

    Args:
        file: Uploaded file. Either an object exposing the path as ``.name``
            or a plain path string.

    Returns:
        A tuple ``(original_display, sharp, hist_img)``: the input clipped at
        its 99.5th percentile as uint8, the enhanced uint8 image, and a PIL
        image of the enhanced image's histogram plot.

    Raises:
        gr.Error: If no file is given, the file cannot be decoded, either
            image dimension is not divisible by 8, or any processing step
            fails.
    """
    if file is None:
        raise gr.Error("Error reading file: no file was provided")

    # Accept both upload objects (path in ``.name``) and bare path strings.
    path = getattr(file, "name", file)

    try:
        img = cv2.imread(str(path), cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
    except Exception as e:
        raise gr.Error(f"Error reading file: {str(e)}") from e
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        raise gr.Error("Error reading file: Invalid or corrupted TIFF file")

    # Check dimensions for wavelet transform
    if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
        raise gr.Error("Image dimensions must be divisible by 8 for wavelet processing")

    # Normalize to [0, 1]
    img_norm = _to_unit_range(img)

    try:
        # Wavelet decomposition
        coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
        cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs

        # Shrink the diagonal detail bands; thresholds are relative to the
        # largest coefficient magnitude so negative peaks count as well.
        cD1 = _soft_threshold(cD1, 0.05)
        cD2 = _soft_threshold(cD2, 0.07)
        # Boost the finest horizontal/vertical detail.
        cH1 = cH1 * 1.2
        cV1 = cV1 * 1.2

        # Reconstruction
        recon = pywt.waverec2(
            [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)],
            'bior1.3',
        )
        recon = np.clip(recon, 0, 1)

        # CLAHE: use a lower clip limit when the 256-bin histogram entropy
        # (at most 8 bits) is above 7.
        entropy = _histogram_entropy(recon)
        clip_limit = 0.02 if entropy > 7 else 0.05
        clahe_img = exposure.equalize_adapthist(recon, clip_limit=clip_limit, kernel_size=64)

        # Gamma correction: a narrow 5-95 percentile spread gets the stronger
        # (smaller) gamma.
        p5, p95 = np.percentile(clahe_img, (5, 95))
        gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
        gamma_img = exposure.adjust_gamma(clahe_img, gamma)

        # Sharpening: detailEnhance is fed a 3-channel 8-bit image here, so
        # convert to BGR and back to grayscale afterwards.
        sharp = cv2.detailEnhance(
            cv2.cvtColor((gamma_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
            sigma_s=12,
            sigma_r=0.15
        )
        sharp = cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)

    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}") from e

    # Prepare outputs
    original_display = _clip_for_display(img)

    # Create histogram plot
    fig, ax = plt.subplots()
    buf = BytesIO()
    try:
        ax.hist(sharp.ravel(), bins=256, range=(0, 255))
        ax.set_title("Enhanced Histogram")
        ax.set_xlabel("Pixel Value")
        ax.set_ylabel("Frequency")
        # Save through the figure object so the output never depends on
        # which figure pyplot currently considers "current".
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)

    # Convert plot to PIL Image
    buf.seek(0)
    hist_img = Image.open(buf)

    return original_display, sharp, hist_img

# UI definition: a file upload plus a "Process" button on the left, the
# clipped original on the right, and the enhanced image with its histogram
# on a second row.
with gr.Blocks(title="MUSICA Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ MUSICA X-ray Image Enhancement")
    gr.Markdown("Upload a 16-bit grayscale TIFF for wavelet-based enhancement")
    
    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Input TIFF",
                # Extensions must carry the leading dot; without it Gradio
                # treats the entry as a file-type category, not an extension.
                file_types=[".tif", ".tiff"],
                height=100
            )
            submit_btn = gr.Button("Process", variant="primary")
        
        with gr.Column():
            original_output = gr.Image(
                label="Original (Clipped)",
                height=400,
                type="numpy"
            )
            
    with gr.Row():
        enhanced_output = gr.Image(
            label="Enhanced Result",
            type="numpy",
            height=400
        )
        hist_output = gr.Image(
            label="Histogram",
            type="pil",
            height=400
        )

    # Run the pipeline on click and fan its three return values out to the
    # three output components, in order.
    submit_btn.click(
        process_tiff,
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output]
    )

    # Example entry pointing at "sample.tif", wired to the same function and
    # outputs as the button.
    gr.Examples(
        examples=[["sample.tif"]],
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output],
        fn=process_tiff
    )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()