yassonee commited on
Commit
1d4ce47
·
verified ·
1 Parent(s): 5db7880

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -102
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
 
 
 
4
 
5
  st.set_page_config(
6
  page_title="Fraktur Detektion",
@@ -8,39 +11,90 @@ st.set_page_config(
8
  initial_sidebar_state="collapsed"
9
  )
10
 
 
11
  st.markdown("""
12
  <style>
13
  .stApp {
 
14
  padding: 0 !important;
15
- height: 100vh !important;
16
- overflow: hidden !important;
17
  }
18
-
19
  .block-container {
20
- padding: 0.25rem !important;
21
  max-width: 100% !important;
22
  }
23
-
24
- .stImage > img {
25
- max-height: 150px !important;
26
- object-fit: contain !important;
27
- }
28
-
29
- h2, h3 {
30
- font-size: 0.9rem !important;
31
- }
32
-
33
- .result-box {
34
- font-size: 0.8rem !important;
35
- margin: 0.2rem 0 !important;
36
- }
37
-
38
- .center-container {
39
  display: flex;
40
  flex-direction: column;
41
  align-items: center;
42
  justify-content: center;
43
- height: 100%;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  }
45
  </style>
46
  """, unsafe_allow_html=True)
@@ -63,101 +117,104 @@ def translate_label(label):
63
  }
64
  return translations.get(label.lower(), label)
65
 
66
- def draw_boxes(image, predictions):
67
  draw = ImageDraw.Draw(image)
68
  for pred in predictions:
69
  box = pred['box']
70
  label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
71
  color = "#2563eb" if pred['score'] > 0.7 else "#eab308"
72
-
73
  draw.rectangle(
74
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
75
  outline=color,
76
  width=2
77
  )
78
-
79
- # Ajouter des points de "chaleur" aux fractures détectées
80
- center_x = (box['xmin'] + box['xmax']) / 2
81
- center_y = (box['ymin'] + box['ymax']) / 2
82
- radius = 5
83
- draw.ellipse(
84
- [(center_x - radius, center_y - radius), (center_x + radius, center_y + radius)],
85
- fill=color
86
- )
87
-
88
- # Label plus compact
89
- draw.text((box['xmin'], box['ymin'] - 15), label, fill="white")
90
  return image
91
 
92
  def main():
93
  models = load_models()
94
-
95
- if "uploaded" not in st.session_state:
96
- st.session_state["uploaded"] = False
97
-
98
- if not st.session_state["uploaded"]:
99
- st.markdown("""
100
- <div class="center-container">
101
- <h2>📤 Röntgenbild Hochladen</h2>
102
- <p>Bitte laden Sie ein Röntgenbild hoch, um die Analyse zu starten.</p>
103
- </div>
104
- """, unsafe_allow_html=True)
105
- uploaded_file = st.file_uploader("Röntgenbild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
106
-
107
- if uploaded_file:
108
- st.session_state["uploaded"] = True
109
- st.session_state["file"] = uploaded_file
110
- st.session_state["analyze"] = False
111
- else:
112
- uploaded_file = st.session_state["file"]
113
-
114
- if not st.session_state.get("analyze", False):
115
- if st.button("🔍 Analyse starten"):
116
- st.session_state["analyze"] = True
117
-
118
- if st.session_state["analyze"]:
119
- col1, col2, col3 = st.columns([1, 1.5, 1])
120
-
121
- with col1:
122
- st.markdown("### 🎯 KI-Analyse")
123
-
124
- st.markdown("**🛡️ Der KnochenWächter**")
125
- image = Image.open(uploaded_file)
126
- predictions_wachter = models["KnochenWächter"](image)
127
- for pred in predictions_wachter:
128
- score_color = "#22c55e" if pred['score'] > 0.7 else "#eab308"
129
- st.markdown(f"""
130
- <div class='result-box'>
131
- <span style='color: {score_color}; font-weight: 500;'>
132
- {pred['score']:.1%}
133
- </span> - {translate_label(pred['label'])}
134
- </div>
135
- """, unsafe_allow_html=True)
136
-
137
- st.markdown("**🎓 Der RöntgenMeister**")
138
- predictions_meister = models["RöntgenMeister"](image)
139
- for pred in predictions_meister:
140
- score_color = "#22c55e" if pred['score'] > 0.7 else "#eab308"
141
- st.markdown(f"""
142
- <div class='result-box'>
143
- <span style='color: {score_color}; font-weight: 500;'>
144
- {pred['score']:.1%}
145
- </span> - {translate_label(pred['label'])}
146
- </div>
147
- """, unsafe_allow_html=True)
148
-
149
- with col2:
150
- st.image(image, use_container_width=True)
151
-
152
- predictions_auge = models["KnochenAuge"](image)
153
- filtered_preds = [p for p in predictions_auge if p['score'] >= 0.6]
154
-
155
- if filtered_preds:
 
 
156
  with col3:
157
- st.markdown("### 👁️ Das KnochenAuge - Lokalisation")
158
- result_image = image.copy()
159
- result_image = draw_boxes(result_image, filtered_preds)
160
- st.image(result_image, use_container_width=True)
 
 
 
 
 
 
 
 
 
161
 
162
  if __name__ == "__main__":
163
- main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
+ import torch
5
+ from typing import List, Dict
6
+ import time
7
 
8
  st.set_page_config(
9
  page_title="Fraktur Detektion",
 
11
  initial_sidebar_state="collapsed"
12
  )
13
 
14
+ # CSS avec animations
15
  st.markdown("""
16
  <style>
17
  .stApp {
18
+ background-color: #f8fafc !important;
19
  padding: 0 !important;
 
 
20
  }
21
+
22
  .block-container {
23
+ padding: 0.5rem !important;
24
  max-width: 100% !important;
25
  }
26
+
27
+ .upload-section {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  display: flex;
29
  flex-direction: column;
30
  align-items: center;
31
  justify-content: center;
32
+ min-height: 50vh;
33
+ animation: fadeIn 0.5s ease-in;
34
+ }
35
+
36
+ .results-section {
37
+ animation: slideUp 0.5s ease-out;
38
+ }
39
+
40
+ .detection-box {
41
+ background: white;
42
+ border-radius: 8px;
43
+ padding: 1rem;
44
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
45
+ margin-bottom: 1rem;
46
+ transform-origin: top;
47
+ animation: scaleIn 0.3s ease-out;
48
+ }
49
+
50
+ .result-item {
51
+ padding: 0.5rem;
52
+ border-radius: 4px;
53
+ margin: 0.25rem 0;
54
+ background: #f1f5f9;
55
+ animation: fadeIn 0.3s ease-out;
56
+ }
57
+
58
+ .image-grid {
59
+ display: grid;
60
+ grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
61
+ gap: 1rem;
62
+ margin-top: 1rem;
63
+ }
64
+
65
+ .image-container {
66
+ background: white;
67
+ border-radius: 8px;
68
+ padding: 0.5rem;
69
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
70
+ animation: scaleIn 0.3s ease-out;
71
+ }
72
+
73
+ @keyframes fadeIn {
74
+ from { opacity: 0; }
75
+ to { opacity: 1; }
76
+ }
77
+
78
+ @keyframes slideUp {
79
+ from { transform: translateY(20px); opacity: 0; }
80
+ to { transform: translateY(0); opacity: 1; }
81
+ }
82
+
83
+ @keyframes scaleIn {
84
+ from { transform: scale(0.95); opacity: 0; }
85
+ to { transform: scale(1); opacity: 1; }
86
+ }
87
+
88
+ /* Compact image style */
89
+ .stImage > img {
90
+ max-height: 300px !important;
91
+ width: auto !important;
92
+ margin: 0 auto;
93
+ object-fit: contain;
94
+ }
95
+
96
+ #MainMenu, footer, header {
97
+ display: none !important;
98
  }
99
  </style>
100
  """, unsafe_allow_html=True)
 
117
  }
118
  return translations.get(label.lower(), label)
119
 
120
+ def draw_boxes(image: Image, predictions: List[Dict]) -> Image:
121
  draw = ImageDraw.Draw(image)
122
  for pred in predictions:
123
  box = pred['box']
124
  label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
125
  color = "#2563eb" if pred['score'] > 0.7 else "#eab308"
126
+
127
  draw.rectangle(
128
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
129
  outline=color,
130
  width=2
131
  )
132
+
133
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
134
+ draw.rectangle(text_bbox, fill=color)
135
+ draw.text((box['xmin'], box['ymin']-15), label, fill="white")
 
 
 
 
 
 
 
 
136
  return image
137
 
138
  def main():
139
  models = load_models()
140
+
141
+ if 'analyzed_images' not in st.session_state:
142
+ st.session_state.analyzed_images = []
143
+
144
+ # Section upload centrée
145
+ st.markdown('<div class="upload-section">', unsafe_allow_html=True)
146
+ st.markdown("### 📤 Röntgenbild Upload")
147
+ uploaded_files = st.file_uploader("", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
148
+ conf_threshold = st.slider(
149
+ "Konfidenzschwelle",
150
+ min_value=0.0, max_value=1.0,
151
+ value=0.60, step=0.05
152
+ )
153
+ analyze_button = st.button("Analysieren")
154
+ st.markdown('</div>', unsafe_allow_html=True)
155
+
156
+ if analyze_button and uploaded_files:
157
+ st.markdown('<div class="results-section">', unsafe_allow_html=True)
158
+
159
+ for uploaded_file in uploaded_files:
160
+ image = Image.open(uploaded_file)
161
+
162
+ # Animation de chargement
163
+ with st.spinner("Analyse läuft..."):
164
+ time.sleep(0.5) # Animation effect
165
+
166
+ col1, col2, col3 = st.columns([1, 1, 1])
167
+
168
+ with col1:
169
+ st.markdown("### 📋 Bild Details")
170
+ st.image(image, use_column_width=True)
171
+
172
+ with col2:
173
+ st.markdown("### 🎯 KI-Analyse")
174
+
175
+ # KnochenWächter
176
+ with st.container():
177
+ st.markdown("#### 🛡️ KnochenWächter")
178
+ predictions = models["KnochenWächter"](image)
179
+ for pred in predictions:
180
+ if pred['score'] >= conf_threshold:
181
+ st.markdown(f"""
182
+ <div class="result-item">
183
+ <span style='color: {"#22c55e" if pred["score"] > 0.7 else "#eab308"}; font-weight: 500;'>
184
+ {pred['score']:.1%}
185
+ </span> - {translate_label(pred['label'])}
186
+ </div>
187
+ """, unsafe_allow_html=True)
188
+
189
+ # RöntgenMeister
190
+ with st.container():
191
+ st.markdown("#### 🎓 RöntgenMeister")
192
+ predictions = models["RöntgenMeister"](image)
193
+ for pred in predictions:
194
+ if pred['score'] >= conf_threshold:
195
+ st.markdown(f"""
196
+ <div class="result-item">
197
+ <span style='color: {"#22c55e" if pred["score"] > 0.7 else "#eab308"}; font-weight: 500;'>
198
+ {pred['score']:.1%}
199
+ </span> - {translate_label(pred['label'])}
200
+ </div>
201
+ """, unsafe_allow_html=True)
202
+
203
+ # Afficher la localisation uniquement si une fracture est détectée
204
  with col3:
205
+ predictions_location = models["KnochenAuge"](image)
206
+ fractures_detected = any(p['score'] >= conf_threshold and 'fracture' in p['label'].lower()
207
+ for p in predictions_location)
208
+
209
+ if fractures_detected:
210
+ st.markdown("### 🔍 Fraktur Lokalisation")
211
+ filtered_preds = [p for p in predictions_location if p['score'] >= conf_threshold]
212
+ if filtered_preds:
213
+ result_image = image.copy()
214
+ result_image = draw_boxes(result_image, filtered_preds)
215
+ st.image(result_image, use_column_width=True)
216
+
217
+ st.markdown('</div>', unsafe_allow_html=True)
218
 
219
  if __name__ == "__main__":
220
+ main()