yassonee commited on
Commit
c9e0580
·
verified ·
1 Parent(s): f86faf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -213
app.py CHANGED
@@ -1,26 +1,105 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
2
- from fastapi.responses import HTMLResponse, JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.middleware.cors import CORSMiddleware
5
- import uvicorn
6
  from transformers import pipeline
7
  from PIL import Image, ImageDraw
8
- import io
9
- import base64
10
  import numpy as np
 
11
 
12
- app = FastAPI()
13
-
14
- # Configuration CORS pour éviter les problèmes de navigateur
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
  )
22
 
23
- # Chargement des modèles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def load_models():
25
  return {
26
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
@@ -29,8 +108,6 @@ def load_models():
29
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
30
  }
31
 
32
- models = load_models()
33
-
34
  def translate_label(label):
35
  translations = {
36
  "fracture": "Knochenbruch",
@@ -49,17 +126,21 @@ def create_heatmap_overlay(image, box, score):
49
  x1, y1 = box['xmin'], box['ymin']
50
  x2, y2 = box['xmax'], box['ymax']
51
 
 
52
  if score > 0.8:
53
- fill_color = (255, 0, 0, 100)
54
  border_color = (255, 0, 0, 255)
55
  elif score > 0.6:
56
- fill_color = (255, 165, 0, 100)
57
  border_color = (255, 165, 0, 255)
58
  else:
59
- fill_color = (255, 255, 0, 100)
60
  border_color = (255, 255, 0, 255)
61
 
 
62
  draw.rectangle([x1, y1, x2, y2], fill=fill_color)
 
 
63
  draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
64
 
65
  return overlay
@@ -71,16 +152,20 @@ def draw_boxes(image, predictions):
71
  box = pred['box']
72
  score = pred['score']
73
 
 
74
  overlay = create_heatmap_overlay(image, box, score)
75
  result_image = Image.alpha_composite(result_image, overlay)
76
 
 
77
  draw = ImageDraw.Draw(result_image)
78
  temp = 36.5 + (score * 2.5)
79
  label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
80
 
 
81
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
82
  draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
83
 
 
84
  draw.text(
85
  (box['xmin'], box['ymin']-20),
86
  label,
@@ -89,200 +174,101 @@ def draw_boxes(image, predictions):
89
 
90
  return result_image
91
 
92
- # Interface HTML de base
93
- @app.get("/", response_class=HTMLResponse)
94
- async def read_root():
95
- return """
96
- <!DOCTYPE html>
97
- <html>
98
- <head>
99
- <title>Fraktur Detektion</title>
100
- <style>
101
- body {
102
- font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
103
- background: #f0f2f5;
104
- margin: 0;
105
- padding: 20px;
106
- }
107
- .container {
108
- max-width: 1200px;
109
- margin: 0 auto;
110
- background: white;
111
- padding: 20px;
112
- border-radius: 10px;
113
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
114
- }
115
- .result-box {
116
- background: #f8f9fa;
117
- padding: 15px;
118
- border-radius: 8px;
119
- margin: 10px 0;
120
- border: 1px solid #e9ecef;
121
- }
122
- .button {
123
- background: #f8f9fa;
124
- border: 1px solid #e9ecef;
125
- padding: 10px 20px;
126
- border-radius: 5px;
127
- cursor: pointer;
128
- transition: all 0.3s ease;
129
- }
130
- .button:hover {
131
- background: #e9ecef;
132
- transform: translateY(-1px);
133
- }
134
- .row {
135
- display: flex;
136
- margin: 20px -10px;
137
- }
138
- .col {
139
- flex: 1;
140
- padding: 0 10px;
141
- }
142
- img {
143
- max-width: 100%;
144
- border-radius: 8px;
145
- }
146
- .loading {
147
- display: none;
148
- text-align: center;
149
- padding: 20px;
150
- }
151
- </style>
152
- </head>
153
- <body>
154
- <div class="container">
155
- <h1>📤 Fraktur Detektion</h1>
156
-
157
- <form id="uploadForm">
158
- <input type="file" id="image" name="image" accept="image/*">
159
- <input type="range" id="threshold" min="0" max="1" step="0.05" value="0.6">
160
- <label for="threshold">Konfidenzschwelle: <span id="thresholdValue">0.60</span></label>
161
- <button type="submit" class="button">Analysieren</button>
162
- </form>
163
-
164
- <div class="loading" id="loading">
165
- Bild wird analysiert...
166
- </div>
167
-
168
- <div class="row">
169
- <div class="col" id="results"></div>
170
- <div class="col" id="imageResult"></div>
171
- </div>
172
- </div>
173
-
174
- <script>
175
- document.getElementById('threshold').addEventListener('input', function(e) {
176
- document.getElementById('thresholdValue').textContent =
177
- parseFloat(e.target.value).toFixed(2);
178
- });
179
 
180
- document.getElementById('uploadForm').addEventListener('submit', async function(e) {
181
- e.preventDefault();
182
-
183
- const formData = new FormData();
184
- formData.append('image', document.getElementById('image').files[0]);
185
- formData.append('threshold', document.getElementById('threshold').value);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- document.getElementById('loading').style.display = 'block';
188
- document.getElementById('results').innerHTML = '';
189
- document.getElementById('imageResult').innerHTML = '';
190
-
191
- try {
192
- const response = await fetch('/analyze', {
193
- method: 'POST',
194
- body: formData
195
- });
196
 
197
- const data = await response.json();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- document.getElementById('results').innerHTML = data.results;
200
- document.getElementById('imageResult').innerHTML =
201
- `<img src="data:image/jpeg;base64,${data.image}" alt="Analyzed image">`;
202
- } catch (error) {
203
- console.error('Error:', error);
204
- } finally {
205
- document.getElementById('loading').style.display = 'none';
206
- }
207
- });
208
- </script>
209
- </body>
210
- </html>
211
- """
212
-
213
- @app.post("/analyze")
214
- async def analyze_image(image: UploadFile = File(...), threshold: float = Form(0.6)):
215
- # Lecture de l'image
216
- image_data = await image.read()
217
- image = Image.open(io.BytesIO(image_data))
218
-
219
- # Analyse
220
- predictions_watcher = models["KnochenWächter"](image)
221
- predictions_master = models["RöntgenMeister"](image)
222
- predictions_locator = models["KnochenAuge"](image)
223
-
224
- has_fracture = False
225
- max_fracture_score = 0
226
-
227
- # Génération du HTML pour les résultats
228
- results_html = "<h2>🔍 Analyse Ergebnisse</h2>"
229
-
230
- # KnochenWächter results
231
- results_html += "<h3>🛡️ KnochenWächter</h3>"
232
- for pred in predictions_watcher:
233
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
234
- if pred['score'] >= threshold and 'fracture' in pred['label'].lower():
235
- has_fracture = True
236
- max_fracture_score = max(max_fracture_score, pred['score'])
237
- results_html += f"""
238
- <div class="result-box">
239
- <span style="color: {confidence_color}; font-weight: 500;">
240
- {pred['score']:.1%}
241
- </span> - {translate_label(pred['label'])}
242
- </div>
243
- """
244
-
245
- # RöntgenMeister results
246
- results_html += "<h3>🎓 RöntgenMeister</h3>"
247
- for pred in predictions_master:
248
- confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
249
- results_html += f"""
250
- <div class="result-box">
251
- <span style="color: {confidence_color}; font-weight: 500;">
252
- {pred['score']:.1%}
253
- </span> - {translate_label(pred['label'])}
254
- </div>
255
- """
256
-
257
- # Probabilité si fracture détectée
258
- if max_fracture_score > 0:
259
- no_fracture_prob = 1 - max_fracture_score
260
- results_html += f"""
261
- <h3>📊 Wahrscheinlichkeit</h3>
262
- <div class="result-box">
263
- Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
264
- Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
265
- </div>
266
- """
267
-
268
- # Traitement de l'image
269
- predictions = models["KnochenAuge"](image)
270
- filtered_preds = [p for p in predictions if p['score'] >= threshold]
271
-
272
- if filtered_preds:
273
- result_image = draw_boxes(image, filtered_preds)
274
- else:
275
- result_image = image
276
-
277
- # Conversion de l'image en base64
278
- buffered = io.BytesIO()
279
- result_image.save(buffered, format="JPEG")
280
- img_str = base64.b64encode(buffered.getvalue()).decode()
281
-
282
- return JSONResponse({
283
- "results": results_html,
284
- "image": img_str
285
- })
286
 
287
  if __name__ == "__main__":
288
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import streamlit as st
 
 
 
 
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
 
 
4
  import numpy as np
5
+ import colorsys
6
 
7
+ st.set_page_config(
8
+ page_title="Fraktur Detektion",
9
+ layout="wide",
10
+ initial_sidebar_state="collapsed"
 
 
 
 
 
11
  )
12
 
13
+ st.markdown("""
14
+ <style>
15
+ .stApp {
16
+ background: #f0f2f5 !important;
17
+ }
18
+
19
+ .block-container {
20
+ padding-top: 0 !important;
21
+ padding-bottom: 0 !important;
22
+ max-width: 1400px !important;
23
+ }
24
+
25
+ .upload-container {
26
+ background: white;
27
+ padding: 1.5rem;
28
+ border-radius: 10px;
29
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
30
+ margin-bottom: 1rem;
31
+ text-align: center;
32
+ }
33
+
34
+ .results-container {
35
+ background: white;
36
+ padding: 1.5rem;
37
+ border-radius: 10px;
38
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
39
+ }
40
+
41
+ .result-box {
42
+ background: #f8f9fa;
43
+ padding: 0.75rem;
44
+ border-radius: 8px;
45
+ margin: 0.5rem 0;
46
+ border: 1px solid #e9ecef;
47
+ }
48
+
49
+ h1, h2, h3, h4, p {
50
+ color: #1a1a1a !important;
51
+ margin: 0.5rem 0 !important;
52
+ }
53
+
54
+ .stImage {
55
+ background: white;
56
+ padding: 0.5rem;
57
+ border-radius: 8px;
58
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
59
+ }
60
+
61
+ .stImage > img {
62
+ max-height: 300px !important;
63
+ width: auto !important;
64
+ margin: 0 auto !important;
65
+ display: block !important;
66
+ }
67
+
68
+ [data-testid="stFileUploader"] {
69
+ width: 100% !important;
70
+ }
71
+
72
+ .stFileUploaderFileName {
73
+ color: #1a1a1a !important;
74
+ }
75
+
76
+ .stButton > button {
77
+ width: 200px;
78
+ background-color: #f8f9fa !important;
79
+ color: #1a1a1a !important;
80
+ border: 1px solid #e9ecef !important;
81
+ padding: 0.5rem 1rem !important;
82
+ border-radius: 5px !important;
83
+ transition: all 0.3s ease !important;
84
+ }
85
+
86
+ .stButton > button:hover {
87
+ background-color: #e9ecef !important;
88
+ transform: translateY(-1px);
89
+ }
90
+
91
+ #MainMenu, footer, header, [data-testid="stToolbar"] {
92
+ display: none !important;
93
+ }
94
+
95
+ /* Hide deprecation warning */
96
+ [data-testid="stExpander"], .element-container:has(>.stAlert) {
97
+ display: none !important;
98
+ }
99
+ </style>
100
+ """, unsafe_allow_html=True)
101
+
102
+ @st.cache_resource
103
  def load_models():
104
  return {
105
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
 
108
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
109
  }
110
 
 
 
111
  def translate_label(label):
112
  translations = {
113
  "fracture": "Knochenbruch",
 
126
  x1, y1 = box['xmin'], box['ymin']
127
  x2, y2 = box['xmax'], box['ymax']
128
 
129
+ # Couleur basée sur le score
130
  if score > 0.8:
131
+ fill_color = (255, 0, 0, 100) # Rouge
132
  border_color = (255, 0, 0, 255)
133
  elif score > 0.6:
134
+ fill_color = (255, 165, 0, 100) # Orange
135
  border_color = (255, 165, 0, 255)
136
  else:
137
+ fill_color = (255, 255, 0, 100) # Jaune
138
  border_color = (255, 255, 0, 255)
139
 
140
+ # Rectangle semi-transparent
141
  draw.rectangle([x1, y1, x2, y2], fill=fill_color)
142
+
143
+ # Bordure
144
  draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
145
 
146
  return overlay
 
152
  box = pred['box']
153
  score = pred['score']
154
 
155
+ # Création de l'overlay
156
  overlay = create_heatmap_overlay(image, box, score)
157
  result_image = Image.alpha_composite(result_image, overlay)
158
 
159
+ # Ajout du texte
160
  draw = ImageDraw.Draw(result_image)
161
  temp = 36.5 + (score * 2.5)
162
  label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
163
 
164
+ # Fond noir pour le texte
165
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
166
  draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
167
 
168
+ # Texte en blanc
169
  draw.text(
170
  (box['xmin'], box['ymin']-20),
171
  label,
 
174
 
175
  return result_image
176
 
177
+ def main():
178
+ models = load_models()
179
+
180
+ with st.container():
181
+ st.write("### 📤 Röntgenbild hochladen")
182
+ uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
183
+
184
+ col1, col2 = st.columns([2, 1])
185
+ with col1:
186
+ conf_threshold = st.slider(
187
+ "Konfidenzschwelle",
188
+ min_value=0.0, max_value=1.0,
189
+ value=0.60, step=0.05,
190
+ label_visibility="visible"
191
+ )
192
+ with col2:
193
+ analyze_button = st.button("Analysieren")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ if uploaded_file and analyze_button:
196
+ with st.spinner("Bild wird analysiert..."):
197
+ image = Image.open(uploaded_file)
198
+ results_container = st.container()
199
+
200
+ predictions_watcher = models["KnochenWächter"](image)
201
+ predictions_master = models["RöntgenMeister"](image)
202
+ predictions_locator = models["KnochenAuge"](image)
203
+
204
+ has_fracture = False
205
+ max_fracture_score = 0
206
+ filtered_locations = [p for p in predictions_locator
207
+ if p['score'] >= conf_threshold]
208
+
209
+ for pred in predictions_watcher:
210
+ if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
211
+ has_fracture = True
212
+ max_fracture_score = max(max_fracture_score, pred['score'])
213
+
214
+ with results_container:
215
+ st.write("### 🔍 Analyse Ergebnisse")
216
+ col1, col2 = st.columns(2)
217
 
218
+ with col1:
219
+ st.write("#### 🤖 KI-Diagnose")
 
 
 
 
 
 
 
220
 
221
+ st.markdown("#### 🛡️ KnochenWächter")
222
+ # Afficher tous les résultats de KnochenWächter
223
+ for pred in predictions_watcher:
224
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
225
+ label_lower = pred['label'].lower()
226
+ # Mettre à jour max_fracture_score seulement pour les fractures
227
+ if pred['score'] >= conf_threshold and 'fracture' in label_lower:
228
+ has_fracture = True
229
+ max_fracture_score = max(max_fracture_score, pred['score'])
230
+ # Afficher tous les résultats
231
+ st.markdown(f"""
232
+ <div class="result-box" style="color: #1a1a1a;">
233
+ <span style="color: {confidence_color}; font-weight: 500;">
234
+ {pred['score']:.1%}
235
+ </span> - {translate_label(pred['label'])}
236
+ </div>
237
+ """, unsafe_allow_html=True)
238
 
239
+ st.markdown("#### 🎓 RöntgenMeister")
240
+ # Afficher tous les résultats de RöntgenMeister
241
+ for pred in predictions_master:
242
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
243
+ st.markdown(f"""
244
+ <div class="result-box" style="color: #1a1a1a;">
245
+ <span style="color: {confidence_color}; font-weight: 500;">
246
+ {pred['score']:.1%}
247
+ </span> - {translate_label(pred['label'])}
248
+ </div>
249
+ """, unsafe_allow_html=True)
250
+
251
+ if max_fracture_score > 0:
252
+ st.write("#### 📊 Wahrscheinlichkeit")
253
+ no_fracture_prob = 1 - max_fracture_score
254
+ st.markdown(f"""
255
+ <div class="result-box" style="color: #1a1a1a;">
256
+ Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
257
+ Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
258
+ </div>
259
+ """, unsafe_allow_html=True)
260
+
261
+ with col2:
262
+ predictions = models["KnochenAuge"](image)
263
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
264
+
265
+ if filtered_preds:
266
+ st.write("#### 🎯 Fraktur Lokalisation")
267
+ result_image = draw_boxes(image, filtered_preds)
268
+ st.image(result_image, use_container_width=True)
269
+ else:
270
+ st.write("#### 🖼️ Röntgenbild")
271
+ st.image(image, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  if __name__ == "__main__":
274
+ main()