yassonee commited on
Commit
952e5d1
·
verified ·
1 Parent(s): 8d54860

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -81
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import numpy as np
5
- from PIL import ImageColor
6
  import colorsys
7
 
8
  st.set_page_config(
@@ -21,29 +20,22 @@ st.markdown("""
21
  padding-top: 0 !important;
22
  padding-bottom: 0 !important;
23
  max-width: 1400px !important;
24
- margin: 0 auto !important;
25
  }
26
 
27
- .main-container {
28
- display: flex;
29
- gap: 1rem;
30
- padding: 1rem;
31
  background: white;
 
32
  border-radius: 10px;
33
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
34
- margin: 1rem;
35
- }
36
-
37
- .upload-section {
38
- flex: 1;
39
- padding: 1rem;
40
- border-radius: 8px;
41
- background: #f8f9fa;
42
  }
43
 
44
- .result-section {
45
- flex: 2;
46
- padding: 1rem;
 
 
47
  }
48
 
49
  .result-box {
@@ -78,7 +70,7 @@ st.markdown("""
78
  }
79
 
80
  .stButton > button {
81
- width: 100%;
82
  background-color: #0066cc !important;
83
  color: white !important;
84
  border: none !important;
@@ -92,16 +84,12 @@ st.markdown("""
92
  transform: translateY(-1px);
93
  }
94
 
95
- #MainMenu, footer, header {
96
  display: none !important;
97
  }
98
 
99
  /* Hide deprecation warning */
100
- [data-testid="stExpander"] {
101
- display: none !important;
102
- }
103
-
104
- .element-container:has(>.stAlert) {
105
  display: none !important;
106
  }
107
  </style>
@@ -119,11 +107,11 @@ def load_models():
119
  def translate_label(label):
120
  translations = {
121
  "fracture": "Knochenbruch",
122
- "no fracture": "Kein Bruch",
123
  "normal": "Normal",
124
  "abnormal": "Auffällig",
125
  "F1": "Knochenbruch",
126
- "NF": "Kein Bruch"
127
  }
128
  return translations.get(label.lower(), label)
129
 
@@ -131,27 +119,22 @@ def create_heatmap_overlay(image, box, score):
131
  overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
132
  draw = ImageDraw.Draw(overlay)
133
 
134
- # Create gradient colors based on confidence
135
  def get_heatmap_color(value):
136
- # Convert to HSV for better control
137
- hue = (1 - value) * 0.3 # 0.3 = reddish, 0 = red
138
  saturation = 0.8
139
  value = 0.9
140
- # Convert back to RGB
141
  rgb = colorsys.hsv_to_rgb(hue, saturation, value)
142
  return tuple(int(x * 255) for x in rgb)
143
 
144
- # Draw the heatmap with gradient
145
  x1, y1 = box['xmin'], box['ymin']
146
  x2, y2 = box['xmax'], box['ymax']
147
 
148
  steps = 20
149
  for i in range(steps):
150
- alpha = int(255 * (1 - i/steps) * 0.6) # Gradient transparency
151
  color = get_heatmap_color(score)
152
  rect_color = color + (alpha,)
153
 
154
- # Create shrinking rectangles for gradient effect
155
  shrink = i * ((x2-x1)/(steps*2))
156
  draw.rectangle([x1+shrink, y1+shrink, x2-shrink, y2-shrink],
157
  fill=rect_color)
@@ -159,19 +142,16 @@ def create_heatmap_overlay(image, box, score):
159
  return overlay
160
 
161
  def draw_boxes(image, predictions):
162
- # Create a copy of the image to work with
163
  result_image = image.copy().convert('RGBA')
164
 
165
  for pred in predictions:
166
  box = pred['box']
167
  score = pred['score']
168
- label = f"{translate_label(pred['label'])} ({score:.2%})"
169
 
170
- # Create and combine heatmap overlay
171
  heatmap = create_heatmap_overlay(image, box, score)
172
  result_image = Image.alpha_composite(result_image, heatmap)
173
 
174
- # Draw border and label
175
  draw = ImageDraw.Draw(result_image)
176
  draw.rectangle(
177
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
@@ -179,7 +159,6 @@ def draw_boxes(image, predictions):
179
  width=2
180
  )
181
 
182
- # Add label with background
183
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
184
  draw.rectangle(text_bbox, fill="#000000AA")
185
  draw.text((box['xmin'], box['ymin']-20), label, fill="white")
@@ -189,80 +168,83 @@ def draw_boxes(image, predictions):
189
  def main():
190
  models = load_models()
191
 
192
- # Initialize session state
193
- if 'analyzed' not in st.session_state:
194
- st.session_state.analyzed = False
195
-
196
- # Main container
197
- st.markdown('<div class="main-container">', unsafe_allow_html=True)
198
-
199
- # Upload section
200
- st.markdown('<div class="upload-section">', unsafe_allow_html=True)
201
- st.markdown("### 📤 Röntgenbild Upload")
202
- uploaded_file = st.file_uploader("", type=['png', 'jpg', 'jpeg'])
203
-
204
- conf_threshold = st.slider(
205
- "Konfidenzschwelle",
206
- min_value=0.0, max_value=1.0,
207
- value=0.60, step=0.05
208
- )
209
-
210
- analyze_button = st.button("Analysieren")
211
- st.markdown('</div>', unsafe_allow_html=True)
212
-
213
- # Results section
214
- st.markdown('<div class="result-section">', unsafe_allow_html=True)
215
-
216
- if uploaded_file and analyze_button:
217
- st.session_state.analyzed = True
218
 
219
- with st.spinner("Analysiere Bild..."):
 
 
 
 
 
 
 
 
 
 
 
220
  image = Image.open(uploaded_file)
221
 
 
 
222
  col1, col2 = st.columns(2)
223
 
224
  with col1:
225
- st.markdown("### 🎯 KI-Analyse")
226
 
227
- # KnochenWächter results
228
- st.markdown("#### 🛡️ KnochenWächter")
229
- predictions = models["KnochenWächter"](image)
230
- for pred in predictions:
231
  if pred['score'] >= conf_threshold:
 
 
 
232
  st.markdown(f"""
233
  <div class="result-box">
234
- <span style="color: {'#0066cc' if pred['score'] > 0.7 else '#ffa500'}; font-weight: 500;">
235
  {pred['score']:.1%}
236
  </span> - {translate_label(pred['label'])}
237
  </div>
238
  """, unsafe_allow_html=True)
239
 
240
- # RöntgenMeister results
241
- st.markdown("#### 🎓 RöntgenMeister")
242
- predictions = models["RöntgenMeister"](image)
243
- for pred in predictions:
244
  if pred['score'] >= conf_threshold:
 
245
  st.markdown(f"""
246
  <div class="result-box">
247
- <span style="color: {'#0066cc' if pred['score'] > 0.7 else '#ffa500'}; font-weight: 500;">
248
  {pred['score']:.1%}
249
  </span> - {translate_label(pred['label'])}
250
  </div>
251
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  with col2:
254
- st.markdown("### 🔍 Visualisierung")
255
  predictions = models["KnochenAuge"](image)
256
  filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
257
 
258
- if filtered_preds:
259
  result_image = draw_boxes(image, filtered_preds)
260
  st.image(result_image, use_container_width=True)
261
  else:
262
  st.image(image, use_container_width=True)
263
-
264
- st.markdown('</div>', unsafe_allow_html=True)
265
- st.markdown('</div>', unsafe_allow_html=True)
266
 
267
  if __name__ == "__main__":
268
  main()
 
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import numpy as np
 
5
  import colorsys
6
 
7
  st.set_page_config(
 
20
  padding-top: 0 !important;
21
  padding-bottom: 0 !important;
22
  max-width: 1400px !important;
 
23
  }
24
 
25
+ .upload-container {
 
 
 
26
  background: white;
27
+ padding: 1.5rem;
28
  border-radius: 10px;
29
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
30
+ margin-bottom: 1rem;
31
+ text-align: center;
 
 
 
 
 
 
32
  }
33
 
34
+ .results-container {
35
+ background: white;
36
+ padding: 1.5rem;
37
+ border-radius: 10px;
38
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
39
  }
40
 
41
  .result-box {
 
70
  }
71
 
72
  .stButton > button {
73
+ width: 200px;
74
  background-color: #0066cc !important;
75
  color: white !important;
76
  border: none !important;
 
84
  transform: translateY(-1px);
85
  }
86
 
87
+ #MainMenu, footer, header, [data-testid="stToolbar"] {
88
  display: none !important;
89
  }
90
 
91
  /* Hide deprecation warning */
92
+ [data-testid="stExpander"], .element-container:has(>.stAlert) {
 
 
 
 
93
  display: none !important;
94
  }
95
  </style>
 
107
  def translate_label(label):
108
  translations = {
109
  "fracture": "Knochenbruch",
110
+ "no fracture": "Kein Knochenbruch",
111
  "normal": "Normal",
112
  "abnormal": "Auffällig",
113
  "F1": "Knochenbruch",
114
+ "NF": "Kein Knochenbruch"
115
  }
116
  return translations.get(label.lower(), label)
117
 
 
119
  overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
120
  draw = ImageDraw.Draw(overlay)
121
 
 
122
  def get_heatmap_color(value):
123
+ hue = (1 - value) * 0.3
 
124
  saturation = 0.8
125
  value = 0.9
 
126
  rgb = colorsys.hsv_to_rgb(hue, saturation, value)
127
  return tuple(int(x * 255) for x in rgb)
128
 
 
129
  x1, y1 = box['xmin'], box['ymin']
130
  x2, y2 = box['xmax'], box['ymax']
131
 
132
  steps = 20
133
  for i in range(steps):
134
+ alpha = int(255 * (1 - i/steps) * 0.6)
135
  color = get_heatmap_color(score)
136
  rect_color = color + (alpha,)
137
 
 
138
  shrink = i * ((x2-x1)/(steps*2))
139
  draw.rectangle([x1+shrink, y1+shrink, x2-shrink, y2-shrink],
140
  fill=rect_color)
 
142
  return overlay
143
 
144
  def draw_boxes(image, predictions):
 
145
  result_image = image.copy().convert('RGBA')
146
 
147
  for pred in predictions:
148
  box = pred['box']
149
  score = pred['score']
150
+ label = f"{translate_label(pred['label'])} ({score:.1%})"
151
 
 
152
  heatmap = create_heatmap_overlay(image, box, score)
153
  result_image = Image.alpha_composite(result_image, heatmap)
154
 
 
155
  draw = ImageDraw.Draw(result_image)
156
  draw.rectangle(
157
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
 
159
  width=2
160
  )
161
 
 
162
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
163
  draw.rectangle(text_bbox, fill="#000000AA")
164
  draw.text((box['xmin'], box['ymin']-20), label, fill="white")
 
168
  def main():
169
  models = load_models()
170
 
171
+ with st.container():
172
+ st.write("### 📤 Röntgenbild hochladen")
173
+ uploaded_file = st.file_uploader("", type=['png', 'jpg', 'jpeg'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
+ col1, col2 = st.columns([2, 1])
176
+ with col1:
177
+ conf_threshold = st.slider(
178
+ "Konfidenzschwelle",
179
+ min_value=0.0, max_value=1.0,
180
+ value=0.60, step=0.05
181
+ )
182
+ with col2:
183
+ analyze_button = st.button("Analysieren")
184
+
185
+ if uploaded_file and analyze_button:
186
+ with st.spinner("Bild wird analysiert..."):
187
  image = Image.open(uploaded_file)
188
 
189
+ st.write("### 🔍 Analyse Ergebnisse")
190
+
191
  col1, col2 = st.columns(2)
192
 
193
  with col1:
194
+ st.write("#### 🤖 KI-Diagnose")
195
 
196
+ # KnochenWächter
197
+ predictions_watcher = models["KnochenWächter"](image)
198
+ has_fracture = False
199
+ for pred in predictions_watcher:
200
  if pred['score'] >= conf_threshold:
201
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
202
+ if 'fracture' in pred['label'].lower():
203
+ has_fracture = True
204
  st.markdown(f"""
205
  <div class="result-box">
206
+ <span style="color: {confidence_color}; font-weight: 500;">
207
  {pred['score']:.1%}
208
  </span> - {translate_label(pred['label'])}
209
  </div>
210
  """, unsafe_allow_html=True)
211
 
212
+ # RöntgenMeister
213
+ predictions_master = models["RöntgenMeister"](image)
214
+ for pred in predictions_master:
 
215
  if pred['score'] >= conf_threshold:
216
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
217
  st.markdown(f"""
218
  <div class="result-box">
219
+ <span style="color: {confidence_color}; font-weight: 500;">
220
  {pred['score']:.1%}
221
  </span> - {translate_label(pred['label'])}
222
  </div>
223
  """, unsafe_allow_html=True)
224
+
225
+ # Calculate and display fracture probability
226
+ fracture_prob = max((p['score'] for p in predictions_watcher
227
+ if 'fracture' in p['label'].lower()), default=0)
228
+ no_fracture_prob = 1 - fracture_prob
229
+
230
+ st.write("#### 📊 Wahrscheinlichkeit")
231
+ st.markdown(f"""
232
+ <div class="result-box">
233
+ Knochenbruch: <strong>{fracture_prob:.1%}</strong><br>
234
+ Kein Knochenbruch: <strong>{no_fracture_prob:.1%}</strong>
235
+ </div>
236
+ """, unsafe_allow_html=True)
237
 
238
  with col2:
239
+ st.write("#### 🎯 Visualisierung")
240
  predictions = models["KnochenAuge"](image)
241
  filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
242
 
243
+ if filtered_preds and has_fracture:
244
  result_image = draw_boxes(image, filtered_preds)
245
  st.image(result_image, use_container_width=True)
246
  else:
247
  st.image(image, use_container_width=True)
 
 
 
248
 
249
  if __name__ == "__main__":
250
  main()