yassonee commited on
Commit
e237c4f
·
verified ·
1 Parent(s): 8ab1fd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -205
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
- import numpy as np
5
- import colorsys
6
 
7
  st.set_page_config(
8
  page_title="Fraktur Detektion",
@@ -12,85 +11,98 @@ st.set_page_config(
12
 
13
  st.markdown("""
14
  <style>
 
15
  .stApp {
16
- background: #f0f2f5 !important;
 
 
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  .block-container {
20
- padding-top: 0 !important;
21
- padding-bottom: 0 !important;
22
- max-width: 1400px !important;
23
  }
24
 
25
- .upload-container {
26
- background: white;
27
- padding: 1.5rem;
28
- border-radius: 10px;
29
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
30
- margin-bottom: 1rem;
31
- text-align: center;
32
  }
33
 
34
- .results-container {
35
- background: white;
36
- padding: 1.5rem;
37
- border-radius: 10px;
38
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
39
  }
40
 
41
- .result-box {
42
- background: #f8f9fa;
43
- padding: 0.75rem;
44
- border-radius: 8px;
45
- margin: 0.5rem 0;
46
- border: 1px solid #e9ecef;
47
  }
48
 
49
- h1, h2, h3, h4, p {
50
- color: #1a1a1a !important;
51
- margin: 0.5rem 0 !important;
 
 
 
 
 
52
  }
53
 
54
- .stImage {
55
- background: white;
56
- padding: 0.5rem;
57
- border-radius: 8px;
58
- box-shadow: 0 1px 3px rgba(0,0,0,0.1);
 
59
  }
60
 
61
- .stImage > img {
62
- max-height: 300px !important;
63
- width: auto !important;
64
- margin: 0 auto !important;
65
- display: block !important;
66
  }
67
 
68
- [data-testid="stFileUploader"] {
69
- width: 100% !important;
 
70
  }
71
 
72
- .stButton > button {
73
- width: 200px;
74
- background-color: #0066cc !important;
75
- color: white !important;
76
- border: none !important;
77
- padding: 0.5rem 1rem !important;
78
- border-radius: 5px !important;
79
- transition: all 0.3s ease !important;
80
  }
81
 
82
- .stButton > button:hover {
83
- background-color: #0052a3 !important;
84
- transform: translateY(-1px);
85
  }
86
 
87
- #MainMenu, footer, header, [data-testid="stToolbar"] {
88
- display: none !important;
89
  }
90
 
91
- /* Hide deprecation warning */
92
- [data-testid="stExpander"], .element-container:has(>.stAlert) {
93
- display: none !important;
94
  }
95
  </style>
96
  """, unsafe_allow_html=True)
@@ -107,7 +119,7 @@ def load_models():
107
  def translate_label(label):
108
  translations = {
109
  "fracture": "Knochenbruch",
110
- "no fracture": "Kein Knochenbruch",
111
  "normal": "Normal",
112
  "abnormal": "Auffällig",
113
  "F1": "Knochenbruch",
@@ -115,184 +127,149 @@ def translate_label(label):
115
  }
116
  return translations.get(label.lower(), label)
117
 
118
- def create_heatmap_overlay(image, box, score):
119
- overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
120
- draw = ImageDraw.Draw(overlay)
121
-
122
- def get_temp_color(value):
123
- if value > 0.8:
124
- return (255, 0, 0) # Rouge vif
125
- elif value > 0.6:
126
- return (255, 69, 0) # Rouge-orange
127
- elif value > 0.4:
128
- return (255, 165, 0) # Orange
129
- else:
130
- return (255, 255, 0) # Jaune
131
-
132
- x1, y1 = box['xmin'], box['ymin']
133
- x2, y2 = box['xmax'], box['ymax']
134
- width = x2 - x1
135
- height = y2 - y1
136
-
137
- steps = 30
138
- for i in range(steps):
139
- alpha = int(255 * (1 - (i / steps)) * 0.7)
140
- base_color = get_temp_color(score)
141
- color = base_color + (alpha,)
142
-
143
- shrink_x = (i * width) / (steps * 2)
144
- shrink_y = (i * height) / (steps * 2)
145
-
146
- draw.rectangle(
147
- [x1 + shrink_x, y1 + shrink_y, x2 - shrink_x, y2 - shrink_y],
148
- fill=color,
149
- outline=None
150
- )
151
-
152
- border_color = get_temp_color(score) + (200,)
153
- draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
154
-
155
- return overlay
156
-
157
  def draw_boxes(image, predictions):
158
- result_image = image.copy().convert('RGBA')
159
-
160
- sorted_predictions = sorted(predictions, key=lambda x: x['score'])
161
-
162
- for pred in sorted_predictions:
163
  box = pred['box']
164
  score = pred['score']
165
 
166
- heatmap = create_heatmap_overlay(image, box, score)
167
- result_image = Image.alpha_composite(result_image, heatmap)
 
 
168
 
169
- draw = ImageDraw.Draw(result_image)
170
- temp = 36.5 + (score * 2.5)
171
- label = f"{translate_label(pred['label'])} ({score:.1%}) {temp:.1f}°C"
 
 
 
 
172
 
173
- text_bbox = draw.textbbox((box['xmin'], box['ymin']-25), label)
174
- padding = 3
175
- text_bbox = (
176
- text_bbox[0]-padding, text_bbox[1]-padding,
177
- text_bbox[2]+padding, text_bbox[3]+padding
178
  )
179
- draw.rectangle(text_bbox, fill="#000000CC")
180
 
 
 
 
 
 
 
 
 
181
  draw.text(
182
- (box['xmin'], box['ymin']-25),
183
  label,
184
- fill="#FFFFFF",
185
- stroke_width=1,
186
- stroke_fill="#000000"
187
  )
188
 
189
- return result_image
190
 
191
  def main():
192
  models = load_models()
193
-
194
- with st.container():
195
- st.write("### 📤 Röntgenbild hochladen")
196
- uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
 
 
 
197
 
198
- col1, col2 = st.columns([2, 1])
199
- with col1:
200
  conf_threshold = st.slider(
201
  "Konfidenzschwelle",
202
  min_value=0.0, max_value=1.0,
203
- value=0.60, step=0.05,
204
- label_visibility="visible"
205
  )
206
- with col2:
207
- analyze_button = st.button("Analysieren")
208
 
209
- if uploaded_file and analyze_button:
210
- with st.spinner("Bild wird analysiert..."):
211
  image = Image.open(uploaded_file)
212
- results_container = st.container()
213
 
214
- predictions_watcher = models["KnochenWächter"](image)
215
- predictions_master = models["RöntgenMeister"](image)
216
- predictions_locator = models["KnochenAuge"](image)
217
 
218
- has_fracture = False
219
- max_fracture_score = 0
220
- filtered_locations = [p for p in predictions_locator
221
- if p['score'] >= conf_threshold
222
- and 'fracture' in p['label'].lower()]
223
 
224
- for pred in predictions_watcher:
225
- if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
226
- has_fracture = True
227
- max_fracture_score = max(max_fracture_score, pred['score'])
 
 
 
228
 
229
- with results_container:
230
- st.write("### 🔍 Analyse Ergebnisse")
231
- col1, col2 = st.columns(2)
232
-
233
- with col1:
234
- st.write("#### 🤖 KI-Diagnose")
235
-
236
- st.write("##### 🛡️ KnochenWächter")
237
- for pred in predictions_watcher:
238
- if pred['score'] >= conf_threshold:
239
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
240
- label_lower = pred['label'].lower()
241
- if 'fracture' in label_lower:
242
- has_fracture = True
243
- max_fracture_score = max(max_fracture_score, pred['score'])
244
- st.markdown(f"""
245
- <div class="result-box" style="color: #1a1a1a;">
246
- <span style="color: {confidence_color}; font-weight: 500;">
247
- {pred['score']:.1%}
248
- </span> - {translate_label(pred['label'])}
249
- </div>
250
- """, unsafe_allow_html=True)
251
-
252
- st.write("#### 🎓 RöntgenMeister")
253
- for pred in predictions_master:
254
- if pred['score'] >= conf_threshold:
255
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
256
- st.markdown(f"""
257
- <div class="result-box" style="color: #1a1a1a;">
258
- <span style="color: {confidence_color}; font-weight: 500;">
259
- {pred['score']:.1%}
260
- </span> - {translate_label(pred['label'])}
261
- </div>
262
- """, unsafe_allow_html=True)
263
-
264
- if max_fracture_score > 0:
265
- st.write("#### 📊 Wahrscheinlichkeit")
266
- no_fracture_prob = 1 - max_fracture_score
267
  st.markdown(f"""
268
- <div class="result-box" style="color: #1a1a1a;">
269
- Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
270
- Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
 
271
  </div>
272
  """, unsafe_allow_html=True)
273
-
274
- with col2:
275
- predictions = models["KnochenAuge"](image)
276
- # Debug: Afficher toutes les prédictions avant filtrage
277
- st.write("Debug - Toutes les prédictions:")
278
- for p in predictions:
279
- st.write(f"Label: {p['label']}, Score: {p['score']}")
280
-
281
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold
282
- and 'fracture' in p['label'].lower()]
283
-
284
- # Debug: Afficher les prédictions filtrées
285
- st.write("Debug - Prédictions filtrées:")
286
- for p in filtered_preds:
287
- st.write(f"Label: {p['label']}, Score: {p['score']}, Box: {p['box']}")
288
-
289
- if filtered_preds:
290
- st.write("#### 🎯 Fraktur Lokalisation")
291
- result_image = draw_boxes(image, filtered_preds)
292
- st.image(result_image, use_container_width=True)
293
- else:
294
- st.write("#### 🖼️ Röntgenbild")
295
- st.image(image, use_container_width=True)
296
 
297
  if __name__ == "__main__":
298
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
+ import torch
 
5
 
6
  st.set_page_config(
7
  page_title="Fraktur Detektion",
 
11
 
12
  st.markdown("""
13
  <style>
14
+ /* Reset et base */
15
  .stApp {
16
+ background-color: var(--background-color) !important;
17
+ padding: 0 !important;
18
+ overflow: hidden !important;
19
  }
20
 
21
+ /* Variables de thème */
22
+ [data-theme="light"] {
23
+ --background-color: #ffffff;
24
+ --text-color: #1f2937;
25
+ --border-color: #e5e7eb;
26
+ --secondary-bg: #f3f4f6;
27
+ }
28
+
29
+ [data-theme="dark"] {
30
+ --background-color: #1f2937;
31
+ --text-color: #f3f4f6;
32
+ --border-color: #4b5563;
33
+ --secondary-bg: #374151;
34
+ }
35
+
36
+ /* Layout principal */
37
  .block-container {
38
+ padding: 0.5rem !important;
39
+ max-width: 100% !important;
 
40
  }
41
 
42
+ /* Contrôles et upload */
43
+ .uploadedFile {
44
+ border: 1px dashed var(--border-color);
45
+ border-radius: 0.375rem;
46
+ padding: 0.25rem;
47
+ background: var(--secondary-bg);
 
48
  }
49
 
50
+ /* Ajustement des colonnes */
51
+ [data-testid="column"] {
52
+ padding: 0 0.5rem !important;
 
 
53
  }
54
 
55
+ /* Images adaptatives */
56
+ .stImage > img {
57
+ width: 100% !important;
58
+ height: auto !important;
59
+ max-height: 400px !important;
60
+ object-fit: contain !important;
61
  }
62
 
63
+ /* Résultats */
64
+ .result-box {
65
+ padding: 0.375rem;
66
+ border-radius: 0.375rem;
67
+ margin: 0.25rem 0;
68
+ background: var(--secondary-bg);
69
+ border: 1px solid var(--border-color);
70
+ color: var(--text-color);
71
  }
72
 
73
+ /* Titres */
74
+ h2, h3 {
75
+ margin: 0 !important;
76
+ padding: 0.5rem 0 !important;
77
+ font-size: 1rem !important;
78
+ color: var(--text-color) !important;
79
  }
80
 
81
+ /* Nettoyage des éléments inutiles */
82
+ #MainMenu, footer, header, .viewerBadge_container__1QSob, .stDeployButton {
83
+ display: none !important;
 
 
84
  }
85
 
86
+ /* Ajustements espacement */
87
+ div[data-testid="stVerticalBlock"] {
88
+ gap: 0.5rem !important;
89
  }
90
 
91
+ .element-container {
92
+ margin: 0.25rem 0 !important;
 
 
 
 
 
 
93
  }
94
 
95
+ /* Éléments spécifiques */
96
+ .high-confidence {
97
+ color: #22c55e !important;
98
  }
99
 
100
+ .medium-confidence {
101
+ color: #eab308 !important;
102
  }
103
 
104
+ .low-confidence {
105
+ color: #dc2626 !important;
 
106
  }
107
  </style>
108
  """, unsafe_allow_html=True)
 
119
  def translate_label(label):
120
  translations = {
121
  "fracture": "Knochenbruch",
122
+ "no fracture": "Kein Bruch",
123
  "normal": "Normal",
124
  "abnormal": "Auffällig",
125
  "F1": "Knochenbruch",
 
127
  }
128
  return translations.get(label.lower(), label)
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def draw_boxes(image, predictions):
131
+ draw = ImageDraw.Draw(image)
132
+ for pred in predictions:
 
 
 
133
  box = pred['box']
134
  score = pred['score']
135
 
136
+ # Calcul de la température simulée basée sur le score
137
+ # Score 1.0 -> 39°C (forte probabilité = "plus chaud")
138
+ # Score 0.6 -> 36.5°C (seuil minimum = "normal")
139
+ temp = 36.5 + (score - 0.6) * (39 - 36.5) / 0.4
140
 
141
+ # Couleur basée sur le score
142
+ if score > 0.8:
143
+ color = "#dc2626" # rouge pour haute confiance
144
+ elif score > 0.7:
145
+ color = "#ea580c" # orange pour confiance moyenne-haute
146
+ else:
147
+ color = "#eab308" # jaune pour confiance moyenne
148
 
149
+ # Dessiner la boîte
150
+ draw.rectangle(
151
+ [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
152
+ outline=color,
153
+ width=2
154
  )
 
155
 
156
+ # Créer le label avec température
157
+ label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
158
+
159
+ # Fond pour le texte
160
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
161
+ draw.rectangle(text_bbox, fill=color)
162
+
163
+ # Texte
164
  draw.text(
165
+ (box['xmin'], box['ymin']-15),
166
  label,
167
+ fill="white"
 
 
168
  )
169
 
170
+ return image
171
 
172
  def main():
173
  models = load_models()
174
+
175
+ # Disposition en deux colonnes principales
176
+ col1, col2 = st.columns([1, 2])
177
+
178
+ with col1:
179
+ st.markdown("### 📤 Röntgenbild Upload")
180
+ uploaded_file = st.file_uploader("", type=['png', 'jpg', 'jpeg'])
181
 
182
+ if uploaded_file:
 
183
  conf_threshold = st.slider(
184
  "Konfidenzschwelle",
185
  min_value=0.0, max_value=1.0,
186
+ value=0.60, step=0.05
 
187
  )
 
 
188
 
189
+ with col2:
190
+ if uploaded_file:
191
  image = Image.open(uploaded_file)
 
192
 
193
+ st.markdown("### 🔍 Meinung der KI-Experten")
 
 
194
 
195
+ # KnochenAuge Analysis (Localisation)
196
+ st.markdown("#### 👁️ Das KnochenAuge - Lokalisation")
197
+ predictions = models["KnochenAuge"](image)
198
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
 
199
 
200
+ if filtered_preds:
201
+ result_image = image.copy()
202
+ result_image = draw_boxes(result_image, filtered_preds)
203
+ st.image(result_image, use_container_width=True)
204
+ else:
205
+ st.image(image, use_container_width=True)
206
+ st.info("Keine signifikanten Auffälligkeiten gefunden.")
207
 
208
+ # Other Models Analysis
209
+ st.markdown("#### 🎯 KI-Analyse")
210
+ col_left, col_right = st.columns(2)
211
+
212
+ def get_score_class(score):
213
+ if score > 0.8:
214
+ return "high-confidence"
215
+ elif score > 0.7:
216
+ return "medium-confidence"
217
+ return "low-confidence"
218
+
219
+ with col_left:
220
+ st.markdown("**🛡️ Der KnochenWächter**")
221
+ predictions = models["KnochenWächter"](image)
222
+ has_predictions = False
223
+ for pred in predictions:
224
+ if pred['score'] >= conf_threshold:
225
+ has_predictions = True
226
+ score_class = get_score_class(pred['score'])
227
+ st.markdown(f"""
228
+ <div class='result-box'>
229
+ <span class='{score_class}' style='font-weight: 500;'>
230
+ {pred['score']:.1%}
231
+ </span> - {translate_label(pred['label'])}
232
+ </div>
233
+ """, unsafe_allow_html=True)
234
+ if not has_predictions:
235
+ st.info("Keine ausreichend sicheren Vorhersagen.")
236
+
237
+ with col_right:
238
+ st.markdown("**🎓 Der RöntgenMeister**")
239
+ predictions = models["RöntgenMeister"](image)
240
+ has_predictions = False
241
+ for pred in predictions:
242
+ if pred['score'] >= conf_threshold:
243
+ has_predictions = True
244
+ score_class = get_score_class(pred['score'])
 
245
  st.markdown(f"""
246
+ <div class='result-box'>
247
+ <span class='{score_class}' style='font-weight: 500;'>
248
+ {pred['score']:.1%}
249
+ </span> - {translate_label(pred['label'])}
250
  </div>
251
  """, unsafe_allow_html=True)
252
+ if not has_predictions:
253
+ st.info("Keine ausreichend sicheren Vorhersagen.")
254
+ else:
255
+ st.info("Bitte laden Sie ein Röntgenbild hoch (JPEG, PNG)")
256
+
257
+ # Script pour la synchronisation du thème
258
+ st.markdown("""
259
+ <script>
260
+ function updateTheme(isDark) {
261
+ document.documentElement.setAttribute('data-theme', isDark ? 'dark' : 'light');
262
+ }
263
+
264
+ window.addEventListener('message', function(e) {
265
+ if (e.data.type === 'theme-change') {
266
+ updateTheme(e.data.theme === 'dark');
267
+ }
268
+ });
269
+
270
+ updateTheme(window.matchMedia('(prefers-color-scheme: dark)').matches);
271
+ </script>
272
+ """, unsafe_allow_html=True)
 
 
273
 
274
  if __name__ == "__main__":
275
  main()