yassonee commited on
Commit
1a77416
·
verified ·
1 Parent(s): e237c4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -182
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
- import torch
 
5
 
6
  st.set_page_config(
7
  page_title="Fraktur Detektion",
@@ -11,98 +12,85 @@ st.set_page_config(
11
 
12
  st.markdown("""
13
  <style>
14
- /* Reset et base */
15
  .stApp {
16
- background-color: var(--background-color) !important;
17
- padding: 0 !important;
18
- overflow: hidden !important;
19
  }
20
 
21
- /* Variables de thème */
22
- [data-theme="light"] {
23
- --background-color: #ffffff;
24
- --text-color: #1f2937;
25
- --border-color: #e5e7eb;
26
- --secondary-bg: #f3f4f6;
27
- }
28
-
29
- [data-theme="dark"] {
30
- --background-color: #1f2937;
31
- --text-color: #f3f4f6;
32
- --border-color: #4b5563;
33
- --secondary-bg: #374151;
34
- }
35
-
36
- /* Layout principal */
37
  .block-container {
38
- padding: 0.5rem !important;
39
- max-width: 100% !important;
 
40
  }
41
 
42
- /* Contrôles et upload */
43
- .uploadedFile {
44
- border: 1px dashed var(--border-color);
45
- border-radius: 0.375rem;
46
- padding: 0.25rem;
47
- background: var(--secondary-bg);
 
48
  }
49
 
50
- /* Ajustement des colonnes */
51
- [data-testid="column"] {
52
- padding: 0 0.5rem !important;
 
 
53
  }
54
 
55
- /* Images adaptatives */
56
- .stImage > img {
57
- width: 100% !important;
58
- height: auto !important;
59
- max-height: 400px !important;
60
- object-fit: contain !important;
61
  }
62
 
63
- /* Résultats */
64
- .result-box {
65
- padding: 0.375rem;
66
- border-radius: 0.375rem;
67
- margin: 0.25rem 0;
68
- background: var(--secondary-bg);
69
- border: 1px solid var(--border-color);
70
- color: var(--text-color);
71
  }
72
 
73
- /* Titres */
74
- h2, h3 {
75
- margin: 0 !important;
76
- padding: 0.5rem 0 !important;
77
- font-size: 1rem !important;
78
- color: var(--text-color) !important;
79
  }
80
 
81
- /* Nettoyage des éléments inutiles */
82
- #MainMenu, footer, header, .viewerBadge_container__1QSob, .stDeployButton {
83
- display: none !important;
 
 
84
  }
85
 
86
- /* Ajustements espacement */
87
- div[data-testid="stVerticalBlock"] {
88
- gap: 0.5rem !important;
89
  }
90
 
91
- .element-container {
92
- margin: 0.25rem 0 !important;
 
 
 
 
 
 
93
  }
94
 
95
- /* Éléments spécifiques */
96
- .high-confidence {
97
- color: #22c55e !important;
98
  }
99
 
100
- .medium-confidence {
101
- color: #eab308 !important;
102
  }
103
 
104
- .low-confidence {
105
- color: #dc2626 !important;
 
106
  }
107
  </style>
108
  """, unsafe_allow_html=True)
@@ -119,7 +107,7 @@ def load_models():
119
  def translate_label(label):
120
  translations = {
121
  "fracture": "Knochenbruch",
122
- "no fracture": "Kein Bruch",
123
  "normal": "Normal",
124
  "abnormal": "Auffällig",
125
  "F1": "Knochenbruch",
@@ -127,149 +115,174 @@ def translate_label(label):
127
  }
128
  return translations.get(label.lower(), label)
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def draw_boxes(image, predictions):
131
- draw = ImageDraw.Draw(image)
132
- for pred in predictions:
 
 
 
133
  box = pred['box']
134
  score = pred['score']
135
 
136
- # Calcul de la température simulée basée sur le score
137
- # Score 1.0 -> 39°C (forte probabilité = "plus chaud")
138
- # Score 0.6 -> 36.5°C (seuil minimum = "normal")
139
- temp = 36.5 + (score - 0.6) * (39 - 36.5) / 0.4
140
 
141
- # Couleur basée sur le score
142
- if score > 0.8:
143
- color = "#dc2626" # rouge pour haute confiance
144
- elif score > 0.7:
145
- color = "#ea580c" # orange pour confiance moyenne-haute
146
- else:
147
- color = "#eab308" # jaune pour confiance moyenne
148
 
149
- # Dessiner la boîte
150
- draw.rectangle(
151
- [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
152
- outline=color,
153
- width=2
154
  )
 
155
 
156
- # Créer le label avec température
157
- label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
158
-
159
- # Fond pour le texte
160
- text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
161
- draw.rectangle(text_bbox, fill=color)
162
-
163
- # Texte
164
  draw.text(
165
- (box['xmin'], box['ymin']-15),
166
  label,
167
- fill="white"
 
 
168
  )
169
 
170
- return image
171
 
172
  def main():
173
  models = load_models()
174
-
175
- # Disposition en deux colonnes principales
176
- col1, col2 = st.columns([1, 2])
177
-
178
- with col1:
179
- st.markdown("### 📤 Röntgenbild Upload")
180
- uploaded_file = st.file_uploader("", type=['png', 'jpg', 'jpeg'])
181
 
182
- if uploaded_file:
 
183
  conf_threshold = st.slider(
184
  "Konfidenzschwelle",
185
  min_value=0.0, max_value=1.0,
186
- value=0.60, step=0.05
 
187
  )
 
 
188
 
189
- with col2:
190
- if uploaded_file:
191
  image = Image.open(uploaded_file)
 
192
 
193
- st.markdown("### 🔍 Meinung der KI-Experten")
 
 
194
 
195
- # KnochenAuge Analysis (Localisation)
196
- st.markdown("#### 👁️ Das KnochenAuge - Lokalisation")
197
- predictions = models["KnochenAuge"](image)
198
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
 
199
 
200
- if filtered_preds:
201
- result_image = image.copy()
202
- result_image = draw_boxes(result_image, filtered_preds)
203
- st.image(result_image, use_container_width=True)
204
- else:
205
- st.image(image, use_container_width=True)
206
- st.info("Keine signifikanten Auffälligkeiten gefunden.")
207
 
208
- # Other Models Analysis
209
- st.markdown("#### 🎯 KI-Analyse")
210
- col_left, col_right = st.columns(2)
211
-
212
- def get_score_class(score):
213
- if score > 0.8:
214
- return "high-confidence"
215
- elif score > 0.7:
216
- return "medium-confidence"
217
- return "low-confidence"
218
-
219
- with col_left:
220
- st.markdown("**🛡️ Der KnochenWächter**")
221
- predictions = models["KnochenWächter"](image)
222
- has_predictions = False
223
- for pred in predictions:
224
- if pred['score'] >= conf_threshold:
225
- has_predictions = True
226
- score_class = get_score_class(pred['score'])
227
- st.markdown(f"""
228
- <div class='result-box'>
229
- <span class='{score_class}' style='font-weight: 500;'>
230
- {pred['score']:.1%}
231
- </span> - {translate_label(pred['label'])}
232
- </div>
233
- """, unsafe_allow_html=True)
234
- if not has_predictions:
235
- st.info("Keine ausreichend sicheren Vorhersagen.")
236
-
237
- with col_right:
238
- st.markdown("**🎓 Der RöntgenMeister**")
239
- predictions = models["RöntgenMeister"](image)
240
- has_predictions = False
241
- for pred in predictions:
242
- if pred['score'] >= conf_threshold:
243
- has_predictions = True
244
- score_class = get_score_class(pred['score'])
 
245
  st.markdown(f"""
246
- <div class='result-box'>
247
- <span class='{score_class}' style='font-weight: 500;'>
248
- {pred['score']:.1%}
249
- </span> - {translate_label(pred['label'])}
250
  </div>
251
  """, unsafe_allow_html=True)
252
- if not has_predictions:
253
- st.info("Keine ausreichend sicheren Vorhersagen.")
254
- else:
255
- st.info("Bitte laden Sie ein Röntgenbild hoch (JPEG, PNG)")
256
-
257
- # Script pour la synchronisation du thème
258
- st.markdown("""
259
- <script>
260
- function updateTheme(isDark) {
261
- document.documentElement.setAttribute('data-theme', isDark ? 'dark' : 'light');
262
- }
263
-
264
- window.addEventListener('message', function(e) {
265
- if (e.data.type === 'theme-change') {
266
- updateTheme(e.data.theme === 'dark');
267
- }
268
- });
269
-
270
- updateTheme(window.matchMedia('(prefers-color-scheme: dark)').matches);
271
- </script>
272
- """, unsafe_allow_html=True)
273
 
274
  if __name__ == "__main__":
275
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
+ import numpy as np
5
+ import colorsys
6
 
7
  st.set_page_config(
8
  page_title="Fraktur Detektion",
 
12
 
13
  st.markdown("""
14
  <style>
 
15
  .stApp {
16
+ background: #f0f2f5 !important;
 
 
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  .block-container {
20
+ padding-top: 0 !important;
21
+ padding-bottom: 0 !important;
22
+ max-width: 1400px !important;
23
  }
24
 
25
+ .upload-container {
26
+ background: white;
27
+ padding: 1.5rem;
28
+ border-radius: 10px;
29
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
30
+ margin-bottom: 1rem;
31
+ text-align: center;
32
  }
33
 
34
+ .results-container {
35
+ background: white;
36
+ padding: 1.5rem;
37
+ border-radius: 10px;
38
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
39
  }
40
 
41
+ .result-box {
42
+ background: #f8f9fa;
43
+ padding: 0.75rem;
44
+ border-radius: 8px;
45
+ margin: 0.5rem 0;
46
+ border: 1px solid #e9ecef;
47
  }
48
 
49
+ h1, h2, h3, h4, p {
50
+ color: #1a1a1a !important;
51
+ margin: 0.5rem 0 !important;
 
 
 
 
 
52
  }
53
 
54
+ .stImage {
55
+ background: white;
56
+ padding: 0.5rem;
57
+ border-radius: 8px;
58
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
 
59
  }
60
 
61
+ .stImage > img {
62
+ max-height: 300px !important;
63
+ width: auto !important;
64
+ margin: 0 auto !important;
65
+ display: block !important;
66
  }
67
 
68
+ [data-testid="stFileUploader"] {
69
+ width: 100% !important;
 
70
  }
71
 
72
+ .stButton > button {
73
+ width: 200px;
74
+ background-color: #0066cc !important;
75
+ color: white !important;
76
+ border: none !important;
77
+ padding: 0.5rem 1rem !important;
78
+ border-radius: 5px !important;
79
+ transition: all 0.3s ease !important;
80
  }
81
 
82
+ .stButton > button:hover {
83
+ background-color: #0052a3 !important;
84
+ transform: translateY(-1px);
85
  }
86
 
87
+ #MainMenu, footer, header, [data-testid="stToolbar"] {
88
+ display: none !important;
89
  }
90
 
91
+ /* Hide deprecation warning */
92
+ [data-testid="stExpander"], .element-container:has(>.stAlert) {
93
+ display: none !important;
94
  }
95
  </style>
96
  """, unsafe_allow_html=True)
 
107
  def translate_label(label):
108
  translations = {
109
  "fracture": "Knochenbruch",
110
+ "no fracture": "Kein Knochenbruch",
111
  "normal": "Normal",
112
  "abnormal": "Auffällig",
113
  "F1": "Knochenbruch",
 
115
  }
116
  return translations.get(label.lower(), label)
117
 
118
+ def create_heatmap_overlay(image, box, score):
119
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
120
+ draw = ImageDraw.Draw(overlay)
121
+
122
+ def get_temp_color(value):
123
+ if value > 0.8:
124
+ return (255, 0, 0) # Rouge vif
125
+ elif value > 0.6:
126
+ return (255, 69, 0) # Rouge-orange
127
+ elif value > 0.4:
128
+ return (255, 165, 0) # Orange
129
+ else:
130
+ return (255, 255, 0) # Jaune
131
+
132
+ x1, y1 = box['xmin'], box['ymin']
133
+ x2, y2 = box['xmax'], box['ymax']
134
+ width = x2 - x1
135
+ height = y2 - y1
136
+
137
+ steps = 30
138
+ for i in range(steps):
139
+ alpha = int(255 * (1 - (i / steps)) * 0.7)
140
+ base_color = get_temp_color(score)
141
+ color = base_color + (alpha,)
142
+
143
+ shrink_x = (i * width) / (steps * 2)
144
+ shrink_y = (i * height) / (steps * 2)
145
+
146
+ draw.rectangle(
147
+ [x1 + shrink_x, y1 + shrink_y, x2 - shrink_x, y2 - shrink_y],
148
+ fill=color,
149
+ outline=None
150
+ )
151
+
152
+ border_color = get_temp_color(score) + (200,)
153
+ draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
154
+
155
+ return overlay
156
+
157
  def draw_boxes(image, predictions):
158
+ result_image = image.copy().convert('RGBA')
159
+
160
+ sorted_predictions = sorted(predictions, key=lambda x: x['score'])
161
+
162
+ for pred in sorted_predictions:
163
  box = pred['box']
164
  score = pred['score']
165
 
166
+ heatmap = create_heatmap_overlay(image, box, score)
167
+ result_image = Image.alpha_composite(result_image, heatmap)
 
 
168
 
169
+ draw = ImageDraw.Draw(result_image)
170
+ temp = 36.5 + (score * 2.5)
171
+ label = f"{translate_label(pred['label'])} ({score:.1%}) {temp:.1f}°C"
 
 
 
 
172
 
173
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']-25), label)
174
+ padding = 3
175
+ text_bbox = (
176
+ text_bbox[0]-padding, text_bbox[1]-padding,
177
+ text_bbox[2]+padding, text_bbox[3]+padding
178
  )
179
+ draw.rectangle(text_bbox, fill="#000000CC")
180
 
 
 
 
 
 
 
 
 
181
  draw.text(
182
+ (box['xmin'], box['ymin']-25),
183
  label,
184
+ fill="#FFFFFF",
185
+ stroke_width=1,
186
+ stroke_fill="#000000"
187
  )
188
 
189
+ return result_image
190
 
191
  def main():
192
  models = load_models()
193
+
194
+ with st.container():
195
+ st.write("### 📤 Röntgenbild hochladen")
196
+ uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
 
 
 
197
 
198
+ col1, col2 = st.columns([2, 1])
199
+ with col1:
200
  conf_threshold = st.slider(
201
  "Konfidenzschwelle",
202
  min_value=0.0, max_value=1.0,
203
+ value=0.60, step=0.05,
204
+ label_visibility="visible"
205
  )
206
+ with col2:
207
+ analyze_button = st.button("Analysieren")
208
 
209
+ if uploaded_file and analyze_button:
210
+ with st.spinner("Bild wird analysiert..."):
211
  image = Image.open(uploaded_file)
212
+ results_container = st.container()
213
 
214
+ predictions_watcher = models["KnochenWächter"](image)
215
+ predictions_master = models["RöntgenMeister"](image)
216
+ predictions_locator = models["KnochenAuge"](image)
217
 
218
+ has_fracture = False
219
+ max_fracture_score = 0
220
+ filtered_locations = [p for p in predictions_locator
221
+ if p['score'] >= conf_threshold
222
+ and 'fracture' in p['label'].lower()]
223
 
224
+ for pred in predictions_watcher:
225
+ if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
226
+ has_fracture = True
227
+ max_fracture_score = max(max_fracture_score, pred['score'])
 
 
 
228
 
229
+ with results_container:
230
+ st.write("### 🔍 Analyse Ergebnisse")
231
+ col1, col2 = st.columns(2)
232
+
233
+ with col1:
234
+ st.write("#### 🤖 KI-Diagnose")
235
+
236
+ st.write("##### 🛡️ KnochenWächter")
237
+ for pred in predictions_watcher:
238
+ if pred['score'] >= conf_threshold:
239
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
240
+ label_lower = pred['label'].lower()
241
+ if 'fracture' in label_lower:
242
+ has_fracture = True
243
+ max_fracture_score = max(max_fracture_score, pred['score'])
244
+ st.markdown(f"""
245
+ <div class="result-box" style="color: #1a1a1a;">
246
+ <span style="color: {confidence_color}; font-weight: 500;">
247
+ {pred['score']:.1%}
248
+ </span> - {translate_label(pred['label'])}
249
+ </div>
250
+ """, unsafe_allow_html=True)
251
+
252
+ st.write("#### 🎓 RöntgenMeister")
253
+ for pred in predictions_master:
254
+ if pred['score'] >= conf_threshold:
255
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
256
+ st.markdown(f"""
257
+ <div class="result-box" style="color: #1a1a1a;">
258
+ <span style="color: {confidence_color}; font-weight: 500;">
259
+ {pred['score']:.1%}
260
+ </span> - {translate_label(pred['label'])}
261
+ </div>
262
+ """, unsafe_allow_html=True)
263
+
264
+ if max_fracture_score > 0:
265
+ st.write("#### 📊 Wahrscheinlichkeit")
266
+ no_fracture_prob = 1 - max_fracture_score
267
  st.markdown(f"""
268
+ <div class="result-box" style="color: #1a1a1a;">
269
+ Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
270
+ Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
 
271
  </div>
272
  """, unsafe_allow_html=True)
273
+
274
+ with col2:
275
+ predictions = models["KnochenAuge"](image)
276
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold
277
+ and 'fracture' in p['label'].lower()]
278
+
279
+ if filtered_preds:
280
+ st.write("#### 🎯 Fraktur Lokalisation")
281
+ result_image = draw_boxes(image, filtered_preds)
282
+ st.image(result_image, use_container_width=True)
283
+ else:
284
+ st.write("#### 🖼️ Röntgenbild")
285
+ st.image(image, use_container_width=True)
 
 
 
 
 
 
 
 
286
 
287
  if __name__ == "__main__":
288
  main()