yassonee commited on
Commit
cc165f9
·
verified ·
1 Parent(s): c6944a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -148
app.py CHANGED
@@ -1,111 +1,75 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoImageProcessor
3
  from PIL import Image, ImageDraw
4
  import torch
5
 
 
6
  st.set_page_config(
7
- page_title="Fraktur Detektion",
8
  layout="wide",
 
9
  initial_sidebar_state="collapsed"
10
  )
11
 
12
- # Configuration du style personnalisé
13
  st.markdown("""
14
  <style>
 
 
 
 
 
 
15
  .stApp {
16
- background-color: transparent !important;
17
  padding: 0 !important;
 
18
  }
19
 
20
- [data-theme="light"] {
21
- --background-color: #ffffff;
22
- --text-color: #1f2937;
23
- --border-color: #e5e7eb;
24
- --button-color: #2563eb;
25
- --button-hover: #1d4ed8;
26
- }
27
-
28
- [data-theme="dark"] {
29
- --background-color: #1f2937;
30
- --text-color: #f3f4f6;
31
- --border-color: #4b5563;
32
- --button-color: #3b82f6;
33
- --button-hover: #2563eb;
34
- }
35
-
36
  .stButton > button {
37
- width: 100% !important;
38
- background-color: var(--button-color) !important;
39
- color: white !important;
40
- border: none !important;
41
- padding: 0.75rem 1.5rem !important;
42
- border-radius: 0.375rem !important;
43
- font-weight: 500 !important;
44
- margin: 1rem 0 !important;
45
- cursor: pointer !important;
46
  }
47
 
48
  .stButton > button:hover {
49
- background-color: var(--button-hover) !important;
50
  }
51
 
 
52
  .block-container {
53
- padding: 0.5rem !important;
54
  max-width: 100% !important;
55
  }
56
 
 
57
  .stImage > img {
58
- max-height: 250px !important;
59
- width: auto !important;
60
- margin: 0 auto !important;
61
- }
62
-
63
- .result-box {
64
- padding: 0.375rem;
65
- border-radius: 0.375rem;
66
- margin: 0.25rem 0;
67
- background: var(--background-color);
68
- border: 1px solid var(--border-color);
69
- color: var(--text-color);
70
- }
71
-
72
- h2, h3, h4 {
73
- margin: 0.5rem 0 !important;
74
- color: var(--text-color) !important;
75
- font-size: 1rem !important;
76
- }
77
-
78
- #MainMenu, footer, header {
79
- display: none !important;
80
  }
81
 
82
- .uploadedFile {
83
- border: 1px dashed var(--border-color);
 
 
84
  border-radius: 0.375rem;
85
- padding: 0.25rem;
86
- }
87
-
88
- div[data-testid="stFileUploader"] {
89
- width: 100%;
90
- }
91
-
92
- /* Cache le message d'upload par défaut */
93
- .uploadedFile small {
94
- display: none !important;
95
  }
96
  </style>
97
  """, unsafe_allow_html=True)
98
 
 
99
  @st.cache_resource
100
  def load_models():
101
  return {
102
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
103
- "KnochenWächter": pipeline("image-classification",
104
- model="Heem2/bone-fracture-detection-using-xray",
105
- image_processor=AutoImageProcessor.from_pretrained("Heem2/bone-fracture-detection-using-xray")),
106
  "RöntgenMeister": pipeline("image-classification",
107
- model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
108
- image_processor=AutoImageProcessor.from_pretrained("nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"))
109
  }
110
 
111
  def draw_boxes(image, predictions):
@@ -113,98 +77,73 @@ def draw_boxes(image, predictions):
113
  for pred in predictions:
114
  if pred['label'].lower() == 'fracture' and pred['score'] > 0.6:
115
  box = pred['box']
116
- label = f"Fraktur ({pred['score']:.2%})"
117
- color = "#2563eb"
118
 
 
119
  draw.rectangle(
120
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
121
  outline=color,
122
  width=2
123
  )
124
 
 
125
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
126
  draw.rectangle(text_bbox, fill=color)
127
  draw.text((box['xmin'], box['ymin']-15), label, fill="white")
 
128
  return image
129
 
130
  def main():
131
- # Initialisation de la session state si nécessaire
132
- if 'models_loaded' not in st.session_state:
133
- st.session_state.models_loaded = False
134
- st.session_state.models = load_models()
135
- st.session_state.models_loaded = True
136
-
137
- # Créer deux colonnes pour l'upload et le bouton
138
- col_upload, col_button = st.columns([3, 1])
139
-
140
- with col_upload:
141
- uploaded_files = st.file_uploader(
142
- "Röntgenbilder hochladen",
143
- type=['png', 'jpg', 'jpeg'],
144
- accept_multiple_files=True,
145
- label_visibility="collapsed"
146
- )
147
-
148
- # Bouton d'analyse dans la deuxième colonne
149
- with col_button:
150
- analyze_clicked = st.button("📋 Analysieren", key="analyze_btn", disabled=not uploaded_files)
151
-
152
- # Analyse des images si le bouton est cliqué
153
- if uploaded_files and analyze_clicked:
154
- col1, col2 = st.columns([1, 1])
155
-
156
- for idx, uploaded_file in enumerate(uploaded_files):
157
- image = Image.open(uploaded_file)
158
-
159
- # Analyse avec KnochenAuge (localisierung)
160
- predictions = st.session_state.models["KnochenAuge"](image)
161
- fractures_found = any(p['label'].lower() == 'fracture' and p['score'] > 0.6 for p in predictions)
162
-
163
- # Afficher uniquement si des fractures sont détectées
164
- if fractures_found:
165
- with col1 if idx % 2 == 0 else col2:
166
- result_image = image.copy()
167
- result_image = draw_boxes(result_image, predictions)
168
- st.image(result_image, caption=f"Bild {idx + 1}", use_column_width=True)
169
-
170
- # Analyse KnochenWächter et RöntgenMeister
171
- pred_wachter = st.session_state.models["KnochenWächter"](image)[0]
172
- pred_meister = st.session_state.models["RöntgenMeister"](image)[0]
173
-
174
- if pred_wachter['score'] > 0.6 or pred_meister['score'] > 0.6:
175
- st.markdown(f"""
176
- <div class='result-box'>
177
- <span style='color: #2563eb'>KnochenWächter:</span> {pred_wachter['score']:.1%}<br>
178
- <span style='color: #2563eb'>RöntgenMeister:</span> {pred_meister['score']:.1%}
179
- </div>
180
- """, unsafe_allow_html=True)
181
-
182
- # Script pour la synchronisation du thème
183
- st.markdown("""
184
- <script>
185
- function updateTheme(isDark) {
186
- document.documentElement.setAttribute('data-theme', isDark ? 'dark' : 'light');
187
- const root = document.documentElement;
188
- if (isDark) {
189
- root.style.setProperty('--background-color', '#1f2937');
190
- root.style.setProperty('--text-color', '#f3f4f6');
191
- root.style.setProperty('--border-color', '#4b5563');
192
- } else {
193
- root.style.setProperty('--background-color', '#ffffff');
194
- root.style.setProperty('--text-color', '#1f2937');
195
- root.style.setProperty('--border-color', '#e5e7eb');
196
- }
197
- }
198
-
199
- window.addEventListener('message', function(e) {
200
- if (e.data.type === 'theme-change') {
201
- updateTheme(e.data.theme === 'dark');
202
- }
203
- });
204
 
205
- updateTheme(window.matchMedia('(prefers-color-scheme: dark)').matches);
206
- </script>
207
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  if __name__ == "__main__":
210
  main()
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import torch
5
 
6
+ # Configuration de la page
7
  st.set_page_config(
 
8
  layout="wide",
9
+ page_title="Fraktur Detektion",
10
  initial_sidebar_state="collapsed"
11
  )
12
 
13
+ # Style personnalisé
14
  st.markdown("""
15
  <style>
16
+ /* Cacher les éléments Streamlit par défaut */
17
+ #MainMenu {visibility: hidden;}
18
+ footer {visibility: hidden;}
19
+ header {visibility: hidden;}
20
+
21
+ /* Style personnalisé pour la page */
22
  .stApp {
23
+ margin: 0;
24
  padding: 0 !important;
25
+ max-width: 100%;
26
  }
27
 
28
+ /* Style pour les boutons */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  .stButton > button {
30
+ width: 100%;
31
+ background-color: #3b82f6;
32
+ color: white;
33
+ padding: 0.5rem 1rem;
34
+ border-radius: 0.375rem;
35
+ border: none;
36
+ font-weight: 500;
 
 
37
  }
38
 
39
  .stButton > button:hover {
40
+ background-color: #2563eb;
41
  }
42
 
43
+ /* Container principal */
44
  .block-container {
45
+ padding: 1rem !important;
46
  max-width: 100% !important;
47
  }
48
 
49
+ /* Style pour les images */
50
  .stImage > img {
51
+ max-height: 300px;
52
+ width: auto;
53
+ margin: 0 auto;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  }
55
 
56
+ /* Style pour les messages d'erreur */
57
+ .stAlert {
58
+ padding: 1rem;
59
+ margin: 1rem 0;
60
  border-radius: 0.375rem;
 
 
 
 
 
 
 
 
 
 
61
  }
62
  </style>
63
  """, unsafe_allow_html=True)
64
 
65
+ # Cache des modèles
66
  @st.cache_resource
67
  def load_models():
68
  return {
69
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
70
+ "KnochenWächter": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
 
 
71
  "RöntgenMeister": pipeline("image-classification",
72
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
 
73
  }
74
 
75
  def draw_boxes(image, predictions):
 
77
  for pred in predictions:
78
  if pred['label'].lower() == 'fracture' and pred['score'] > 0.6:
79
  box = pred['box']
80
+ label = f"Fraktur ({pred['score']:.1%})"
81
+ color = "#EF4444" # Rouge
82
 
83
+ # Dessiner le rectangle
84
  draw.rectangle(
85
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
86
  outline=color,
87
  width=2
88
  )
89
 
90
+ # Ajouter le label
91
  text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
92
  draw.rectangle(text_bbox, fill=color)
93
  draw.text((box['xmin'], box['ymin']-15), label, fill="white")
94
+
95
  return image
96
 
97
  def main():
98
+ # Chargement des modèles
99
+ models = load_models()
100
+
101
+ # Upload des images
102
+ uploaded_files = st.file_uploader(
103
+ "Röntgenbilder hochladen",
104
+ type=['png', 'jpg', 'jpeg'],
105
+ accept_multiple_files=True,
106
+ label_visibility="collapsed"
107
+ )
108
+
109
+ if uploaded_files:
110
+ # Bouton d'analyse
111
+ if st.button("Bilder analysieren", key="analyze_button"):
112
+ # Création des colonnes pour l'affichage
113
+ cols = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ for idx, uploaded_file in enumerate(uploaded_files):
116
+ image = Image.open(uploaded_file)
117
+
118
+ # Analyse avec KnochenAuge
119
+ predictions = models["KnochenAuge"](image)
120
+ fractures_found = any(p['label'].lower() == 'fracture' and p['score'] > 0.6 for p in predictions)
121
+
122
+ if fractures_found:
123
+ with cols[idx % 2]:
124
+ # Créer une copie de l'image pour le dessin
125
+ result_image = image.copy()
126
+ result_image = draw_boxes(result_image, predictions)
127
+ st.image(result_image, use_column_width=True)
128
+
129
+ # Analyses supplémentaires
130
+ pred_wachter = models["KnochenWächter"](image)[0]
131
+ pred_meister = models["RöntgenMeister"](image)[0]
132
+
133
+ if pred_wachter['score'] > 0.6 or pred_meister['score'] > 0.6:
134
+ st.markdown(
135
+ f"""
136
+ <div style='background-color: #1F2937; color: white; padding: 1rem; border-radius: 0.375rem;'>
137
+ <div style='margin-bottom: 0.5rem;'>
138
+ <span style='color: #60A5FA;'>KnochenWächter:</span> {pred_wachter['score']:.1%}
139
+ </div>
140
+ <div>
141
+ <span style='color: #60A5FA;'>RöntgenMeister:</span> {pred_meister['score']:.1%}
142
+ </div>
143
+ </div>
144
+ """,
145
+ unsafe_allow_html=True
146
+ )
147
 
148
  if __name__ == "__main__":
149
  main()