yassonee commited on
Commit
f86faf1
·
verified ·
1 Parent(s): 592f7b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -199
app.py CHANGED
@@ -1,105 +1,26 @@
1
- import streamlit as st
 
 
 
 
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
 
 
4
  import numpy as np
5
- import colorsys
6
 
7
- st.set_page_config(
8
- page_title="Fraktur Detektion",
9
- layout="wide",
10
- initial_sidebar_state="collapsed"
11
- )
12
 
13
- st.markdown("""
14
- <style>
15
- .stApp {
16
- background: #f0f2f5 !important;
17
- }
18
-
19
- .block-container {
20
- padding-top: 0 !important;
21
- padding-bottom: 0 !important;
22
- max-width: 1400px !important;
23
- }
24
-
25
- .upload-container {
26
- background: white;
27
- padding: 1.5rem;
28
- border-radius: 10px;
29
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
30
- margin-bottom: 1rem;
31
- text-align: center;
32
- }
33
-
34
- .results-container {
35
- background: white;
36
- padding: 1.5rem;
37
- border-radius: 10px;
38
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
39
- }
40
-
41
- .result-box {
42
- background: #f8f9fa;
43
- padding: 0.75rem;
44
- border-radius: 8px;
45
- margin: 0.5rem 0;
46
- border: 1px solid #e9ecef;
47
- }
48
-
49
- h1, h2, h3, h4, p {
50
- color: #1a1a1a !important;
51
- margin: 0.5rem 0 !important;
52
- }
53
-
54
- .stImage {
55
- background: white;
56
- padding: 0.5rem;
57
- border-radius: 8px;
58
- box-shadow: 0 1px 3px rgba(0,0,0,0.1);
59
- }
60
-
61
- .stImage > img {
62
- max-height: 300px !important;
63
- width: auto !important;
64
- margin: 0 auto !important;
65
- display: block !important;
66
- }
67
-
68
- [data-testid="stFileUploader"] {
69
- width: 100% !important;
70
- }
71
-
72
- .stFileUploaderFileName {
73
- color: #1a1a1a !important;
74
- }
75
-
76
- .stButton > button {
77
- width: 200px;
78
- background-color: #f8f9fa !important;
79
- color: #1a1a1a !important;
80
- border: 1px solid #e9ecef !important;
81
- padding: 0.5rem 1rem !important;
82
- border-radius: 5px !important;
83
- transition: all 0.3s ease !important;
84
- }
85
-
86
- .stButton > button:hover {
87
- background-color: #e9ecef !important;
88
- transform: translateY(-1px);
89
- }
90
-
91
- #MainMenu, footer, header, [data-testid="stToolbar"] {
92
- display: none !important;
93
- }
94
-
95
- /* Hide deprecation warning */
96
- [data-testid="stExpander"], .element-container:has(>.stAlert) {
97
- display: none !important;
98
- }
99
- </style>
100
- """, unsafe_allow_html=True)
101
 
102
- @st.cache_resource
103
  def load_models():
104
  return {
105
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
@@ -108,6 +29,8 @@ def load_models():
108
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
109
  }
110
 
 
 
111
  def translate_label(label):
112
  translations = {
113
  "fracture": "Knochenbruch",
@@ -126,21 +49,17 @@ def create_heatmap_overlay(image, box, score):
126
  x1, y1 = box['xmin'], box['ymin']
127
  x2, y2 = box['xmax'], box['ymax']
128
 
129
- # Couleur basée sur le score
130
  if score > 0.8:
131
- fill_color = (255, 0, 0, 100) # Rouge
132
  border_color = (255, 0, 0, 255)
133
  elif score > 0.6:
134
- fill_color = (255, 165, 0, 100) # Orange
135
  border_color = (255, 165, 0, 255)
136
  else:
137
- fill_color = (255, 255, 0, 100) # Jaune
138
  border_color = (255, 255, 0, 255)
139
 
140
- # Rectangle semi-transparent
141
  draw.rectangle([x1, y1, x2, y2], fill=fill_color)
142
-
143
- # Bordure
144
  draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
145
 
146
  return overlay
@@ -152,20 +71,16 @@ def draw_boxes(image, predictions):
152
  box = pred['box']
153
  score = pred['score']
154
 
155
- # Création de l'overlay
156
  overlay = create_heatmap_overlay(image, box, score)
157
  result_image = Image.alpha_composite(result_image, overlay)
158
 
159
- # Ajout du texte
160
  draw = ImageDraw.Draw(result_image)
161
  temp = 36.5 + (score * 2.5)
162
  label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
163
 
164
- # Fond noir pour le texte
165
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
166
  draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
167
 
168
- # Texte en blanc
169
  draw.text(
170
  (box['xmin'], box['ymin']-20),
171
  label,
@@ -174,101 +89,200 @@ def draw_boxes(image, predictions):
174
 
175
  return result_image
176
 
177
- def main():
178
- models = load_models()
179
-
180
- with st.container():
181
- st.write("### 📤 Röntgenbild hochladen")
182
- uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
183
-
184
- col1, col2 = st.columns([2, 1])
185
- with col1:
186
- conf_threshold = st.slider(
187
- "Konfidenzschwelle",
188
- min_value=0.0, max_value=1.0,
189
- value=0.60, step=0.05,
190
- label_visibility="visible"
191
- )
192
- with col2:
193
- analyze_button = st.button("Analysieren")
194
-
195
- if uploaded_file and analyze_button:
196
- with st.spinner("Bild wird analysiert..."):
197
- image = Image.open(uploaded_file)
198
- results_container = st.container()
199
-
200
- predictions_watcher = models["KnochenWächter"](image)
201
- predictions_master = models["RöntgenMeister"](image)
202
- predictions_locator = models["KnochenAuge"](image)
203
-
204
- has_fracture = False
205
- max_fracture_score = 0
206
- filtered_locations = [p for p in predictions_locator
207
- if p['score'] >= conf_threshold]
208
-
209
- for pred in predictions_watcher:
210
- if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
211
- has_fracture = True
212
- max_fracture_score = max(max_fracture_score, pred['score'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
- with results_container:
215
- st.write("### 🔍 Analyse Ergebnisse")
216
- col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- with col1:
219
- st.write("#### 🤖 KI-Diagnose")
220
-
221
- st.markdown("#### 🛡️ KnochenWächter")
222
- # Afficher tous les résultats de KnochenWächter
223
- for pred in predictions_watcher:
224
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
225
- label_lower = pred['label'].lower()
226
- # Mettre à jour max_fracture_score seulement pour les fractures
227
- if pred['score'] >= conf_threshold and 'fracture' in label_lower:
228
- has_fracture = True
229
- max_fracture_score = max(max_fracture_score, pred['score'])
230
- # Afficher tous les résultats
231
- st.markdown(f"""
232
- <div class="result-box" style="color: #1a1a1a;">
233
- <span style="color: {confidence_color}; font-weight: 500;">
234
- {pred['score']:.1%}
235
- </span> - {translate_label(pred['label'])}
236
- </div>
237
- """, unsafe_allow_html=True)
238
-
239
- st.markdown("#### 🎓 RöntgenMeister")
240
- # Afficher tous les résultats de RöntgenMeister
241
- for pred in predictions_master:
242
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
243
- st.markdown(f"""
244
- <div class="result-box" style="color: #1a1a1a;">
245
- <span style="color: {confidence_color}; font-weight: 500;">
246
- {pred['score']:.1%}
247
- </span> - {translate_label(pred['label'])}
248
- </div>
249
- """, unsafe_allow_html=True)
250
-
251
- if max_fracture_score > 0:
252
- st.write("#### 📊 Wahrscheinlichkeit")
253
- no_fracture_prob = 1 - max_fracture_score
254
- st.markdown(f"""
255
- <div class="result-box" style="color: #1a1a1a;">
256
- Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
257
- Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
258
- </div>
259
- """, unsafe_allow_html=True)
260
 
261
- with col2:
262
- predictions = models["KnochenAuge"](image)
263
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
 
 
 
 
 
 
 
 
264
 
265
- if filtered_preds:
266
- st.write("#### 🎯 Fraktur Lokalisation")
267
- result_image = draw_boxes(image, filtered_preds)
268
- st.image(result_image, use_container_width=True)
269
- else:
270
- st.write("#### 🖼️ Röntgenbild")
271
- st.image(image, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  if __name__ == "__main__":
274
- main()
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ import uvicorn
6
  from transformers import pipeline
7
  from PIL import Image, ImageDraw
8
+ import io
9
+ import base64
10
  import numpy as np
 
11
 
12
+ app = FastAPI()
 
 
 
 
13
 
14
+ # Configuration CORS pour éviter les problèmes de navigateur
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Chargement des modèles
24
  def load_models():
25
  return {
26
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
 
29
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
30
  }
31
 
32
+ models = load_models()
33
+
34
  def translate_label(label):
35
  translations = {
36
  "fracture": "Knochenbruch",
 
49
  x1, y1 = box['xmin'], box['ymin']
50
  x2, y2 = box['xmax'], box['ymax']
51
 
 
52
  if score > 0.8:
53
+ fill_color = (255, 0, 0, 100)
54
  border_color = (255, 0, 0, 255)
55
  elif score > 0.6:
56
+ fill_color = (255, 165, 0, 100)
57
  border_color = (255, 165, 0, 255)
58
  else:
59
+ fill_color = (255, 255, 0, 100)
60
  border_color = (255, 255, 0, 255)
61
 
 
62
  draw.rectangle([x1, y1, x2, y2], fill=fill_color)
 
 
63
  draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
64
 
65
  return overlay
 
71
  box = pred['box']
72
  score = pred['score']
73
 
 
74
  overlay = create_heatmap_overlay(image, box, score)
75
  result_image = Image.alpha_composite(result_image, overlay)
76
 
 
77
  draw = ImageDraw.Draw(result_image)
78
  temp = 36.5 + (score * 2.5)
79
  label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
80
 
 
81
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
82
  draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
83
 
 
84
  draw.text(
85
  (box['xmin'], box['ymin']-20),
86
  label,
 
89
 
90
  return result_image
91
 
92
+ # Interface HTML de base
93
+ @app.get("/", response_class=HTMLResponse)
94
+ async def read_root():
95
+ return """
96
+ <!DOCTYPE html>
97
+ <html>
98
+ <head>
99
+ <title>Fraktur Detektion</title>
100
+ <style>
101
+ body {
102
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
103
+ background: #f0f2f5;
104
+ margin: 0;
105
+ padding: 20px;
106
+ }
107
+ .container {
108
+ max-width: 1200px;
109
+ margin: 0 auto;
110
+ background: white;
111
+ padding: 20px;
112
+ border-radius: 10px;
113
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
114
+ }
115
+ .result-box {
116
+ background: #f8f9fa;
117
+ padding: 15px;
118
+ border-radius: 8px;
119
+ margin: 10px 0;
120
+ border: 1px solid #e9ecef;
121
+ }
122
+ .button {
123
+ background: #f8f9fa;
124
+ border: 1px solid #e9ecef;
125
+ padding: 10px 20px;
126
+ border-radius: 5px;
127
+ cursor: pointer;
128
+ transition: all 0.3s ease;
129
+ }
130
+ .button:hover {
131
+ background: #e9ecef;
132
+ transform: translateY(-1px);
133
+ }
134
+ .row {
135
+ display: flex;
136
+ margin: 20px -10px;
137
+ }
138
+ .col {
139
+ flex: 1;
140
+ padding: 0 10px;
141
+ }
142
+ img {
143
+ max-width: 100%;
144
+ border-radius: 8px;
145
+ }
146
+ .loading {
147
+ display: none;
148
+ text-align: center;
149
+ padding: 20px;
150
+ }
151
+ </style>
152
+ </head>
153
+ <body>
154
+ <div class="container">
155
+ <h1>📤 Fraktur Detektion</h1>
156
 
157
+ <form id="uploadForm">
158
+ <input type="file" id="image" name="image" accept="image/*">
159
+ <input type="range" id="threshold" min="0" max="1" step="0.05" value="0.6">
160
+ <label for="threshold">Konfidenzschwelle: <span id="thresholdValue">0.60</span></label>
161
+ <button type="submit" class="button">Analysieren</button>
162
+ </form>
163
+
164
+ <div class="loading" id="loading">
165
+ Bild wird analysiert...
166
+ </div>
167
+
168
+ <div class="row">
169
+ <div class="col" id="results"></div>
170
+ <div class="col" id="imageResult"></div>
171
+ </div>
172
+ </div>
173
+
174
+ <script>
175
+ document.getElementById('threshold').addEventListener('input', function(e) {
176
+ document.getElementById('thresholdValue').textContent =
177
+ parseFloat(e.target.value).toFixed(2);
178
+ });
179
+
180
+ document.getElementById('uploadForm').addEventListener('submit', async function(e) {
181
+ e.preventDefault();
182
 
183
+ const formData = new FormData();
184
+ formData.append('image', document.getElementById('image').files[0]);
185
+ formData.append('threshold', document.getElementById('threshold').value);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
+ document.getElementById('loading').style.display = 'block';
188
+ document.getElementById('results').innerHTML = '';
189
+ document.getElementById('imageResult').innerHTML = '';
190
+
191
+ try {
192
+ const response = await fetch('/analyze', {
193
+ method: 'POST',
194
+ body: formData
195
+ });
196
+
197
+ const data = await response.json();
198
 
199
+ document.getElementById('results').innerHTML = data.results;
200
+ document.getElementById('imageResult').innerHTML =
201
+ `<img src="data:image/jpeg;base64,${data.image}" alt="Analyzed image">`;
202
+ } catch (error) {
203
+ console.error('Error:', error);
204
+ } finally {
205
+ document.getElementById('loading').style.display = 'none';
206
+ }
207
+ });
208
+ </script>
209
+ </body>
210
+ </html>
211
+ """
212
+
213
+ @app.post("/analyze")
214
+ async def analyze_image(image: UploadFile = File(...), threshold: float = Form(0.6)):
215
+ # Lecture de l'image
216
+ image_data = await image.read()
217
+ image = Image.open(io.BytesIO(image_data))
218
+
219
+ # Analyse
220
+ predictions_watcher = models["KnochenWächter"](image)
221
+ predictions_master = models["RöntgenMeister"](image)
222
+ predictions_locator = models["KnochenAuge"](image)
223
+
224
+ has_fracture = False
225
+ max_fracture_score = 0
226
+
227
+ # Génération du HTML pour les résultats
228
+ results_html = "<h2>🔍 Analyse Ergebnisse</h2>"
229
+
230
+ # KnochenWächter results
231
+ results_html += "<h3>🛡️ KnochenWächter</h3>"
232
+ for pred in predictions_watcher:
233
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
234
+ if pred['score'] >= threshold and 'fracture' in pred['label'].lower():
235
+ has_fracture = True
236
+ max_fracture_score = max(max_fracture_score, pred['score'])
237
+ results_html += f"""
238
+ <div class="result-box">
239
+ <span style="color: {confidence_color}; font-weight: 500;">
240
+ {pred['score']:.1%}
241
+ </span> - {translate_label(pred['label'])}
242
+ </div>
243
+ """
244
+
245
+ # RöntgenMeister results
246
+ results_html += "<h3>🎓 RöntgenMeister</h3>"
247
+ for pred in predictions_master:
248
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
249
+ results_html += f"""
250
+ <div class="result-box">
251
+ <span style="color: {confidence_color}; font-weight: 500;">
252
+ {pred['score']:.1%}
253
+ </span> - {translate_label(pred['label'])}
254
+ </div>
255
+ """
256
+
257
+ # Probabilité si fracture détectée
258
+ if max_fracture_score > 0:
259
+ no_fracture_prob = 1 - max_fracture_score
260
+ results_html += f"""
261
+ <h3>📊 Wahrscheinlichkeit</h3>
262
+ <div class="result-box">
263
+ Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
264
+ Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
265
+ </div>
266
+ """
267
+
268
+ # Traitement de l'image
269
+ predictions = models["KnochenAuge"](image)
270
+ filtered_preds = [p for p in predictions if p['score'] >= threshold]
271
+
272
+ if filtered_preds:
273
+ result_image = draw_boxes(image, filtered_preds)
274
+ else:
275
+ result_image = image
276
+
277
+ # Conversion de l'image en base64
278
+ buffered = io.BytesIO()
279
+ result_image.save(buffered, format="JPEG")
280
+ img_str = base64.b64encode(buffered.getvalue()).decode()
281
+
282
+ return JSONResponse({
283
+ "results": results_html,
284
+ "image": img_str
285
+ })
286
 
287
  if __name__ == "__main__":
288
+ uvicorn.run(app, host="0.0.0.0", port=7860)