yassonee commited on
Commit
8d54860
·
verified ·
1 Parent(s): e59f527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -116
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
- import torch
 
 
5
 
6
  st.set_page_config(
7
  page_title="Fraktur Detektion",
@@ -11,34 +13,37 @@ st.set_page_config(
11
 
12
  st.markdown("""
13
  <style>
14
- /* Base styles */
15
  .stApp {
16
  background: #f0f2f5 !important;
17
  }
18
 
19
  .block-container {
20
- padding: 1rem !important;
 
21
  max-width: 1400px !important;
22
  margin: 0 auto !important;
23
  }
24
 
25
- /* Custom containers */
26
- .center-upload {
 
 
27
  background: white;
28
- padding: 2rem;
29
  border-radius: 10px;
30
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
31
- margin-bottom: 2rem;
32
- text-align: center;
33
  }
34
 
35
- .analysis-container {
36
- background: white;
37
- padding: 1.5rem;
38
- border-radius: 10px;
39
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
40
- margin-bottom: 1rem;
41
- animation: slideIn 0.5s ease-out;
 
 
 
42
  }
43
 
44
  .result-box {
@@ -49,13 +54,11 @@ st.markdown("""
49
  border: 1px solid #e9ecef;
50
  }
51
 
52
- /* Text styles */
53
  h1, h2, h3, h4, p {
54
  color: #1a1a1a !important;
55
  margin: 0.5rem 0 !important;
56
  }
57
 
58
- /* Image styles */
59
  .stImage {
60
  background: white;
61
  padding: 0.5rem;
@@ -64,38 +67,18 @@ st.markdown("""
64
  }
65
 
66
  .stImage > img {
67
- max-height: 250px !important;
68
  width: auto !important;
69
  margin: 0 auto !important;
70
  display: block !important;
71
  }
72
 
73
- /* Animations */
74
- @keyframes slideIn {
75
- from {
76
- opacity: 0;
77
- transform: translateY(-10px);
78
- }
79
- to {
80
- opacity: 1;
81
- transform: translateY(0);
82
- }
83
- }
84
-
85
- /* Hide unnecessary elements */
86
- #MainMenu, footer {
87
- display: none !important;
88
- }
89
-
90
- /* Custom columns spacing */
91
- [data-testid="column"] {
92
- padding: 0.5rem !important;
93
- background: transparent !important;
94
  }
95
 
96
- /* Button styling */
97
  .stButton > button {
98
- width: 200px;
99
  background-color: #0066cc !important;
100
  color: white !important;
101
  border: none !important;
@@ -108,6 +91,19 @@ st.markdown("""
108
  background-color: #0052a3 !important;
109
  transform: translateY(-1px);
110
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  </style>
112
  """, unsafe_allow_html=True)
113
 
@@ -125,109 +121,148 @@ def translate_label(label):
125
  "fracture": "Knochenbruch",
126
  "no fracture": "Kein Bruch",
127
  "normal": "Normal",
128
- "abnormal": "Auffällig"
 
 
129
  }
130
  return translations.get(label.lower(), label)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def draw_boxes(image, predictions):
133
- draw = ImageDraw.Draw(image)
 
 
134
  for pred in predictions:
135
  box = pred['box']
136
- label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
137
- color = "#0066cc" if pred['score'] > 0.7 else "#ffa500"
138
 
 
 
 
 
 
 
139
  draw.rectangle(
140
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
141
- outline=color,
142
  width=2
143
  )
144
 
145
- text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
146
- draw.rectangle(text_bbox, fill=color)
147
- draw.text((box['xmin'], box['ymin']-15), label, fill="white")
148
- return image
 
 
149
 
150
  def main():
151
  models = load_models()
152
 
153
- # Initial upload section
154
- st.markdown('<div class="center-upload">', unsafe_allow_html=True)
 
 
 
 
 
 
 
155
  st.markdown("### 📤 Röntgenbild Upload")
156
- uploaded_files = st.file_uploader("", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
157
 
158
  conf_threshold = st.slider(
159
  "Konfidenzschwelle",
160
  min_value=0.0, max_value=1.0,
161
- value=0.60, step=0.05,
162
- key='confidence'
163
  )
164
 
165
- analyze_button = st.button("Analysieren", key='analyze')
166
  st.markdown('</div>', unsafe_allow_html=True)
167
-
168
- # Analysis section
169
- if analyze_button and uploaded_files:
170
- for idx, uploaded_file in enumerate(uploaded_files):
171
- st.markdown(f'<div class="analysis-container">', unsafe_allow_html=True)
 
 
 
 
172
 
173
- with st.spinner("Analysiere Bild..."):
174
- image = Image.open(uploaded_file)
175
-
176
- # Create three columns
177
- col1, col2, col3 = st.columns(3)
178
-
179
- # Column 1: Original Image
180
- with col1:
181
- st.markdown("### 🖼️ Original")
182
- st.image(image, use_column_width=True)
183
 
184
- # Column 2: AI Analysis
185
- with col2:
186
- st.markdown("### 🤖 KI-Analyse")
187
-
188
- # KnochenWächter results
189
- predictions = models["KnochenWächter"](image)
190
- st.markdown("#### 🛡️ KnochenWächter")
191
- for pred in predictions:
192
- if pred['score'] >= conf_threshold:
193
- st.markdown(f"""
194
- <div class="result-box">
195
- <span style='color: {"#0066cc" if pred["score"] > 0.7 else "#ffa500"}; font-weight: 500;'>
196
- {pred['score']:.1%}
197
- </span> - {translate_label(pred['label'])}
198
- </div>
199
- """, unsafe_allow_html=True)
200
-
201
- # RöntgenMeister results
202
- predictions = models["RöntgenMeister"](image)
203
- st.markdown("#### 🎓 RöntgenMeister")
204
- for pred in predictions:
205
- if pred['score'] >= conf_threshold:
206
- st.markdown(f"""
207
- <div class="result-box">
208
- <span style='color: {"#0066cc" if pred["score"] > 0.7 else "#ffa500"}; font-weight: 500;'>
209
- {pred['score']:.1%}
210
- </span> - {translate_label(pred['label'])}
211
- </div>
212
- """, unsafe_allow_html=True)
213
 
214
- # Column 3: Localization (only if fracture detected)
215
- with col3:
216
- predictions = models["KnochenAuge"](image)
217
- has_fracture = any(
218
- p['score'] >= conf_threshold and 'fracture' in p['label'].lower()
219
- for p in predictions
220
- )
221
-
222
- if has_fracture:
223
- st.markdown("### 🔍 Fraktur Lokalisation")
224
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
225
- if filtered_preds:
226
- result_image = image.copy()
227
- result_image = draw_boxes(result_image, filtered_preds)
228
- st.image(result_image, use_column_width=True)
229
 
230
- st.markdown('</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if __name__ == "__main__":
233
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
+ import numpy as np
5
+ from PIL import ImageColor
6
+ import colorsys
7
 
8
  st.set_page_config(
9
  page_title="Fraktur Detektion",
 
13
 
14
  st.markdown("""
15
  <style>
 
16
  .stApp {
17
  background: #f0f2f5 !important;
18
  }
19
 
20
  .block-container {
21
+ padding-top: 0 !important;
22
+ padding-bottom: 0 !important;
23
  max-width: 1400px !important;
24
  margin: 0 auto !important;
25
  }
26
 
27
+ .main-container {
28
+ display: flex;
29
+ gap: 1rem;
30
+ padding: 1rem;
31
  background: white;
 
32
  border-radius: 10px;
33
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
34
+ margin: 1rem;
 
35
  }
36
 
37
+ .upload-section {
38
+ flex: 1;
39
+ padding: 1rem;
40
+ border-radius: 8px;
41
+ background: #f8f9fa;
42
+ }
43
+
44
+ .result-section {
45
+ flex: 2;
46
+ padding: 1rem;
47
  }
48
 
49
  .result-box {
 
54
  border: 1px solid #e9ecef;
55
  }
56
 
 
57
  h1, h2, h3, h4, p {
58
  color: #1a1a1a !important;
59
  margin: 0.5rem 0 !important;
60
  }
61
 
 
62
  .stImage {
63
  background: white;
64
  padding: 0.5rem;
 
67
  }
68
 
69
  .stImage > img {
70
+ max-height: 300px !important;
71
  width: auto !important;
72
  margin: 0 auto !important;
73
  display: block !important;
74
  }
75
 
76
+ [data-testid="stFileUploader"] {
77
+ width: 100% !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  }
79
 
 
80
  .stButton > button {
81
+ width: 100%;
82
  background-color: #0066cc !important;
83
  color: white !important;
84
  border: none !important;
 
91
  background-color: #0052a3 !important;
92
  transform: translateY(-1px);
93
  }
94
+
95
+ #MainMenu, footer, header {
96
+ display: none !important;
97
+ }
98
+
99
+ /* Hide deprecation warning */
100
+ [data-testid="stExpander"] {
101
+ display: none !important;
102
+ }
103
+
104
+ .element-container:has(>.stAlert) {
105
+ display: none !important;
106
+ }
107
  </style>
108
  """, unsafe_allow_html=True)
109
 
 
121
  "fracture": "Knochenbruch",
122
  "no fracture": "Kein Bruch",
123
  "normal": "Normal",
124
+ "abnormal": "Auffällig",
125
+ "F1": "Knochenbruch",
126
+ "NF": "Kein Bruch"
127
  }
128
  return translations.get(label.lower(), label)
129
 
130
+ def create_heatmap_overlay(image, box, score):
131
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
132
+ draw = ImageDraw.Draw(overlay)
133
+
134
+ # Create gradient colors based on confidence
135
+ def get_heatmap_color(value):
136
+ # Convert to HSV for better control
137
+ hue = (1 - value) * 0.3 # 0.3 = reddish, 0 = red
138
+ saturation = 0.8
139
+ value = 0.9
140
+ # Convert back to RGB
141
+ rgb = colorsys.hsv_to_rgb(hue, saturation, value)
142
+ return tuple(int(x * 255) for x in rgb)
143
+
144
+ # Draw the heatmap with gradient
145
+ x1, y1 = box['xmin'], box['ymin']
146
+ x2, y2 = box['xmax'], box['ymax']
147
+
148
+ steps = 20
149
+ for i in range(steps):
150
+ alpha = int(255 * (1 - i/steps) * 0.6) # Gradient transparency
151
+ color = get_heatmap_color(score)
152
+ rect_color = color + (alpha,)
153
+
154
+ # Create shrinking rectangles for gradient effect
155
+ shrink = i * ((x2-x1)/(steps*2))
156
+ draw.rectangle([x1+shrink, y1+shrink, x2-shrink, y2-shrink],
157
+ fill=rect_color)
158
+
159
+ return overlay
160
+
161
  def draw_boxes(image, predictions):
162
+ # Create a copy of the image to work with
163
+ result_image = image.copy().convert('RGBA')
164
+
165
  for pred in predictions:
166
  box = pred['box']
167
+ score = pred['score']
168
+ label = f"{translate_label(pred['label'])} ({score:.2%})"
169
 
170
+ # Create and combine heatmap overlay
171
+ heatmap = create_heatmap_overlay(image, box, score)
172
+ result_image = Image.alpha_composite(result_image, heatmap)
173
+
174
+ # Draw border and label
175
+ draw = ImageDraw.Draw(result_image)
176
  draw.rectangle(
177
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
178
+ outline="#FFFFFF",
179
  width=2
180
  )
181
 
182
+ # Add label with background
183
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
184
+ draw.rectangle(text_bbox, fill="#000000AA")
185
+ draw.text((box['xmin'], box['ymin']-20), label, fill="white")
186
+
187
+ return result_image
188
 
189
  def main():
190
  models = load_models()
191
 
192
+ # Initialize session state
193
+ if 'analyzed' not in st.session_state:
194
+ st.session_state.analyzed = False
195
+
196
+ # Main container
197
+ st.markdown('<div class="main-container">', unsafe_allow_html=True)
198
+
199
+ # Upload section
200
+ st.markdown('<div class="upload-section">', unsafe_allow_html=True)
201
  st.markdown("### 📤 Röntgenbild Upload")
202
+ uploaded_file = st.file_uploader("", type=['png', 'jpg', 'jpeg'])
203
 
204
  conf_threshold = st.slider(
205
  "Konfidenzschwelle",
206
  min_value=0.0, max_value=1.0,
207
+ value=0.60, step=0.05
 
208
  )
209
 
210
+ analyze_button = st.button("Analysieren")
211
  st.markdown('</div>', unsafe_allow_html=True)
212
+
213
+ # Results section
214
+ st.markdown('<div class="result-section">', unsafe_allow_html=True)
215
+
216
+ if uploaded_file and analyze_button:
217
+ st.session_state.analyzed = True
218
+
219
+ with st.spinner("Analysiere Bild..."):
220
+ image = Image.open(uploaded_file)
221
 
222
+ col1, col2 = st.columns(2)
223
+
224
+ with col1:
225
+ st.markdown("### 🎯 KI-Analyse")
 
 
 
 
 
 
226
 
227
+ # KnochenWächter results
228
+ st.markdown("#### 🛡️ KnochenWächter")
229
+ predictions = models["KnochenWächter"](image)
230
+ for pred in predictions:
231
+ if pred['score'] >= conf_threshold:
232
+ st.markdown(f"""
233
+ <div class="result-box">
234
+ <span style="color: {'#0066cc' if pred['score'] > 0.7 else '#ffa500'}; font-weight: 500;">
235
+ {pred['score']:.1%}
236
+ </span> - {translate_label(pred['label'])}
237
+ </div>
238
+ """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
+ # RöntgenMeister results
241
+ st.markdown("#### 🎓 RöntgenMeister")
242
+ predictions = models["RöntgenMeister"](image)
243
+ for pred in predictions:
244
+ if pred['score'] >= conf_threshold:
245
+ st.markdown(f"""
246
+ <div class="result-box">
247
+ <span style="color: {'#0066cc' if pred['score'] > 0.7 else '#ffa500'}; font-weight: 500;">
248
+ {pred['score']:.1%}
249
+ </span> - {translate_label(pred['label'])}
250
+ </div>
251
+ """, unsafe_allow_html=True)
 
 
 
252
 
253
+ with col2:
254
+ st.markdown("### 🔍 Visualisierung")
255
+ predictions = models["KnochenAuge"](image)
256
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
257
+
258
+ if filtered_preds:
259
+ result_image = draw_boxes(image, filtered_preds)
260
+ st.image(result_image, use_container_width=True)
261
+ else:
262
+ st.image(image, use_container_width=True)
263
+
264
+ st.markdown('</div>', unsafe_allow_html=True)
265
+ st.markdown('</div>', unsafe_allow_html=True)
266
 
267
  if __name__ == "__main__":
268
  main()