ahmedxeno commited on
Commit
9700750
·
verified ·
1 Parent(s): 4ef5297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -52
app.py CHANGED
@@ -1,75 +1,125 @@
1
  import cv2
2
  import numpy as np
3
- import matplotlib.pyplot as plt
4
  import pywt
5
  from skimage import exposure
6
  import gradio as gr
7
  from PIL import Image
8
  from io import BytesIO
 
9
 
10
  def process_tiff(file):
11
- # Load 16-bit TIFF
12
- img = cv2.imread(file.name, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
13
- if img is None:
14
- raise gr.Error("Could not read TIFF file. Ensure it's 16-bit grayscale.")
15
-
 
 
 
16
  # Normalize to [0, 1]
17
  img_norm = img.astype(np.float32) / 65535.0
18
 
19
- # Wavelet decomposition
 
 
 
20
  try:
 
21
  coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
22
- except ValueError as e:
23
- raise gr.Error(f"Image dimensions must be divisible by 8. {str(e)}")
24
-
25
- # Processing coefficients (same as original)
26
- cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs
27
- cD1 = pywt.threshold(cD1, 0.05*np.max(cD1), 'soft')
28
- cD2 = pywt.threshold(cD2, 0.07*np.max(cD2), 'soft')
29
- cH1 *= 1.2; cV1 *= 1.2
30
-
31
- # Reconstruction
32
- recon = pywt.waverec2([cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)], 'bior1.3')
33
- recon = np.clip(recon, 0, 1)
34
-
35
- # CLAHE
36
- entropy = -np.sum(recon * np.log2(recon + 1e-7))
37
- clahe_img = exposure.equalize_adapthist(recon, clip_limit=0.02 if entropy >7 else 0.05, kernel_size=64)
38
-
39
- # Gamma correction
40
- p5, p95 = np.percentile(clahe_img, (5, 95))
41
- gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
42
- gamma_img = exposure.adjust_gamma(clahe_img, gamma)
43
-
44
- # Sharpening
45
- sharp = cv2.detailEnhance(cv2.cvtColor((gamma_img*255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
46
- sigma_s=12, sigma_r=0.15)
47
- sharp = cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)
 
 
 
 
 
48
 
49
  # Prepare outputs
50
- original_display = (np.clip(img/np.percentile(img, 99.5), 0, 1)*255).astype(np.uint8)
 
 
 
 
 
 
 
51
 
 
52
  buf = BytesIO()
53
- plt.hist(sharp.ravel(), bins=256, range=(0, 255))
54
- plt.title("Enhanced Histogram")
55
  plt.savefig(buf, format='png', bbox_inches='tight')
56
- plt.close()
57
  hist_img = Image.open(buf)
58
 
59
  return original_display, sharp, hist_img
60
 
61
- # Create Gradio interface
62
- interface = gr.Interface(
63
- fn=process_tiff,
64
- inputs=gr.File(label="Upload TIFF", file_types=[".tif", ".tiff"]),
65
- outputs=[
66
- gr.Image(label="Original (Clipped)"),
67
- gr.Image(label="Enhanced Image"),
68
- gr.Image(label="Histogram")
69
- ],
70
- title="MUSICA Image Enhancement",
71
- description="Upload 16-bit grayscale TIFF for enhancement using wavelet-based processing",
72
- allow_flagging="never"
73
- )
74
-
75
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
 
3
  import pywt
4
  from skimage import exposure
5
  import gradio as gr
6
  from PIL import Image
7
  from io import BytesIO
8
+ import matplotlib.pyplot as plt
9
 
10
  def process_tiff(file):
11
+ # Read file content
12
+ try:
13
+ img = cv2.imread(file.name, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
14
+ if img is None:
15
+ raise ValueError("Invalid or corrupted TIFF file")
16
+ except Exception as e:
17
+ raise gr.Error(f"Error reading file: {str(e)}")
18
+
19
  # Normalize to [0, 1]
20
  img_norm = img.astype(np.float32) / 65535.0
21
 
22
+ # Check dimensions for wavelet transform
23
+ if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
24
+ raise gr.Error("Image dimensions must be divisible by 8 for wavelet processing")
25
+
26
  try:
27
+ # Wavelet decomposition
28
  coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
29
+ cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs
30
+
31
+ # Processing coefficients
32
+ cD1 = pywt.threshold(cD1, 0.05*np.max(cD1), 'soft')
33
+ cD2 = pywt.threshold(cD2, 0.07*np.max(cD2), 'soft')
34
+ cH1 *= 1.2
35
+ cV1 *= 1.2
36
+
37
+ # Reconstruction
38
+ recon = pywt.waverec2([cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)], 'bior1.3')
39
+ recon = np.clip(recon, 0, 1)
40
+
41
+ # CLAHE
42
+ entropy = -np.sum(recon * np.log2(recon + 1e-7))
43
+ clahe_img = exposure.equalize_adapthist(recon, clip_limit=0.02 if entropy > 7 else 0.05, kernel_size=64)
44
+
45
+ # Gamma correction
46
+ p5, p95 = np.percentile(clahe_img, (5, 95))
47
+ gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
48
+ gamma_img = exposure.adjust_gamma(clahe_img, gamma)
49
+
50
+ # Sharpening
51
+ sharp = cv2.detailEnhance(
52
+ cv2.cvtColor((gamma_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
53
+ sigma_s=12,
54
+ sigma_r=0.15
55
+ )
56
+ sharp = cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)
57
+
58
+ except Exception as e:
59
+ raise gr.Error(f"Processing error: {str(e)}")
60
 
61
  # Prepare outputs
62
+ original_display = (np.clip(img / np.percentile(img, 99.5), 0, 1) * 255).astype(np.uint8)
63
+
64
+ # Create histogram plot
65
+ fig, ax = plt.subplots()
66
+ ax.hist(sharp.ravel(), bins=256, range=(0, 255))
67
+ ax.set_title("Enhanced Histogram")
68
+ ax.set_xlabel("Pixel Value")
69
+ ax.set_ylabel("Frequency")
70
 
71
+ # Convert plot to PIL Image
72
  buf = BytesIO()
 
 
73
  plt.savefig(buf, format='png', bbox_inches='tight')
74
+ plt.close(fig)
75
  hist_img = Image.open(buf)
76
 
77
  return original_display, sharp, hist_img
78
 
79
+ with gr.Blocks(title="MUSICA Enhancement", theme=gr.themes.Soft()) as demo:
80
+ gr.Markdown("# 🖼️ MUSICA X-ray Image Enhancement")
81
+ gr.Markdown("Upload a 16-bit grayscale TIFF for wavelet-based enhancement")
82
+
83
+ with gr.Row():
84
+ with gr.Column():
85
+ file_input = gr.File(
86
+ label="Input TIFF",
87
+ file_types=["tif", "tiff"],
88
+ height=100
89
+ )
90
+ submit_btn = gr.Button("Process", variant="primary")
91
+
92
+ with gr.Column():
93
+ original_output = gr.Image(
94
+ label="Original (Clipped)",
95
+ height=400,
96
+ type="numpy"
97
+ )
98
+
99
+ with gr.Row():
100
+ enhanced_output = gr.Image(
101
+ label="Enhanced Result",
102
+ type="numpy",
103
+ height=400
104
+ )
105
+ hist_output = gr.Image(
106
+ label="Histogram",
107
+ type="pil",
108
+ height=400
109
+ )
110
+
111
+ submit_btn.click(
112
+ process_tiff,
113
+ inputs=file_input,
114
+ outputs=[original_output, enhanced_output, hist_output]
115
+ )
116
+
117
+ gr.Examples(
118
+ examples=[["sample.tif"]],
119
+ inputs=file_input,
120
+ outputs=[original_output, enhanced_output, hist_output],
121
+ fn=process_tiff
122
+ )
123
+
124
+ if __name__ == "__main__":
125
+ demo.launch()