RioJune commited on
Commit
78b1606
Β·
1 Parent(s): 774d5c3
Files changed (2) hide show
  1. .gitignore +2 -1
  2. app.py +97 -24
.gitignore CHANGED
@@ -1 +1,2 @@
1
- app_backup.py
 
 
1
+ app_backup.py
2
+ app_backup_original.py
app.py CHANGED
@@ -69,7 +69,8 @@ st.markdown("""
69
  @st.cache_resource
70
  def load_model():
71
  REVISION = 'refs/pr/6'
72
- MODEL_NAME = "Anonymous-AC/AD-KD"
 
73
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
 
75
  config_model = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
@@ -155,10 +156,12 @@ with col2:
155
 
156
  st.subheader("⚠️ Warning:")
157
  st.write("""
158
- - **🚫 Please avoid uploading non-frontal chest X-ray images**. Our model has been specifically trained on **frontal chest X-ray images**.
159
- - This demo is intended for **πŸ”¬ research purposes only** and should **❌ not be used for medical diagnoses**.
160
- - The model’s responses may contain **πŸ€– hallucinations or incorrect information**. Always consult a **πŸ‘¨β€βš•οΈ medical professional** for accurate diagnosis and advice.
161
- """)
 
 
162
 
163
  st.markdown("</div>", unsafe_allow_html=True)
164
 
@@ -182,10 +185,45 @@ if st.button("Run Inference on Example", key="example"):
182
  inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
183
 
184
  with st.spinner("Processing... ⏳"):
185
- # Generate the result
186
- generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=np_image.shape[:2])
190
 
191
  detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=np_image.shape[:2])
@@ -221,18 +259,6 @@ with col2:
221
 
222
  uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
223
 
224
- # if uploaded_file:
225
- # image = Image.open(uploaded_file).convert("RGB")
226
- # image = apply_transform(image) # Ensure the uploaded image is transformed correctly
227
- # st.image(image, caption="Uploaded Image", width=400)
228
-
229
- # # Let user select dataset and disease dynamically
230
- # disease_choice = disease_choice if disease_choice else example_diseases[0]
231
-
232
- # # Get Definition Priority: Dataset -> User Input
233
- # definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
234
- # if not definition:
235
- # definition = st.text_input("Enter Definition Manually πŸ“", value="")
236
 
237
  col1, col2 = st.columns([1, 2])
238
 
@@ -264,10 +290,11 @@ with col2:
264
 
265
  st.subheader("⚠️ Warning:")
266
  st.write("""
267
- - **🚫 Please avoid uploading non-frontal chest X-ray images**. Our model has been specifically trained on **frontal chest X-ray images**.
268
- - This demo is intended for **πŸ”¬ research purposes only** and should **❌ not be used for medical diagnoses**.
269
- - The model’s responses may contain **πŸ€– hallucinations or incorrect information**. Always consult a **πŸ‘¨β€βš•οΈ medical professional** for accurate diagnosis and advice.
270
- """)
 
271
 
272
  # Run inference after upload
273
  if st.button("Run Inference πŸƒβ€β™‚οΈ"):
@@ -285,9 +312,52 @@ if st.button("Run Inference πŸƒβ€β™‚οΈ"):
285
  inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
286
 
287
  with st.spinner("Processing... ⏳"):
288
- generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=np_image.shape[:2])
292
 
293
  detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=np_image.shape[:2])
@@ -311,3 +381,6 @@ if st.button("Run Inference πŸƒβ€β™‚οΈ"):
311
 
312
  # Display the generated text
313
  st.write("**Generated Text:**", generated_text)
 
 
 
 
69
  @st.cache_resource
70
  def load_model():
71
  REVISION = 'refs/pr/6'
72
+ # MODEL_NAME = "RioJune/AD-KD-MICCAI25"
73
+ MODEL_NAME = '/u/home/lj0/Checkpoints/AD-KD-MICCAI25'
74
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
 
76
  config_model = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
 
156
 
157
  st.subheader("⚠️ Warning:")
158
  st.write("""
159
+ - **🚫 Please avoid uploading non-frontal chest X-ray images.** Our model has been specifically trained on **frontal chest X-ray images** only.
160
+ - This demo is intended for **πŸ”¬ research purposes only** and should **❌ not be used for medical diagnoses**.
161
+ - The model’s responses may contain **<span style='color:#dc3545; font-weight:bold;'>πŸ€– hallucinations or incorrect information</span>**.
162
+ - Always consult a **<span style='color:#dc3545; font-weight:bold;'>πŸ‘¨β€βš•οΈ medical professional</span>** for accurate diagnosis and advice.
163
+ """, unsafe_allow_html=True)
164
+
165
 
166
  st.markdown("</div>", unsafe_allow_html=True)
167
 
 
185
  inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
186
 
187
  with st.spinner("Processing... ⏳"):
188
+ outputs = model.generate(
189
+ input_ids=inputs["input_ids"],
190
+ pixel_values=inputs["pixel_values"],
191
+ max_new_tokens=1024,
192
+ num_beams=3,
193
+ output_scores=True, # Make sure we get the scores/logits
194
+ return_dict_in_generate=True # Ensures you get both sequences and scores in the output
195
+ )
196
+
197
+
198
+ # Ensure transition_scores is properly extracted
199
+ transition_scores = model.compute_transition_scores(
200
+ outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
201
+ )
202
+
203
+ # Get the generated token IDs (ignoring the input tokens part)
204
+ generated_ids = outputs.sequences
205
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
206
 
207
+ # Get input length
208
+ input_length = inputs.input_ids.shape[1]
209
+ generated_tokens = outputs.sequences
210
+
211
+ # Calculate output length (number of generated tokens)
212
+ output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
213
+
214
+ # Get length penalty
215
+ length_penalty = model.generation_config.length_penalty
216
+
217
+ # Calculate total score for the generated sentence
218
+ reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
219
+
220
+ # Convert log-probability to probability (0-1 range)
221
+ probabilities = np.exp(reconstructed_scores.cpu().numpy())
222
+
223
+ # Streamlit UI to display the result
224
+ st.markdown(f"**🎯 Probability of the Results:** <span style='color:#28a745; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>", unsafe_allow_html=True)
225
+
226
+
227
  predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=np_image.shape[:2])
228
 
229
  detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=np_image.shape[:2])
 
259
 
260
  uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
261
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  col1, col2 = st.columns([1, 2])
264
 
 
290
 
291
  st.subheader("⚠️ Warning:")
292
  st.write("""
293
+ - **🚫 Please avoid uploading non-frontal chest X-ray images.** Our model has been specifically trained on **frontal chest X-ray images** only.
294
+ - This demo is intended for **πŸ”¬ research purposes only** and should **❌ not be used for medical diagnoses**.
295
+ - The model’s responses may contain **<span style='color:#dc3545; font-weight:bold;'>πŸ€– hallucinations or incorrect information</span>**.
296
+ - Always consult a **<span style='color:#dc3545; font-weight:bold;'>πŸ‘¨β€βš•οΈ medical professional</span>** for accurate diagnosis and advice.
297
+ """, unsafe_allow_html=True)
298
 
299
  # Run inference after upload
300
  if st.button("Run Inference πŸƒβ€β™‚οΈ"):
 
312
  inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
313
 
314
  with st.spinner("Processing... ⏳"):
315
+ # generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
316
+ # generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
317
+
318
+ outputs = model.generate(
319
+ input_ids=inputs["input_ids"],
320
+ pixel_values=inputs["pixel_values"],
321
+ max_new_tokens=1024,
322
+ num_beams=3,
323
+ output_scores=True, # Make sure we get the scores/logits
324
+ return_dict_in_generate=True # Ensures you get both sequences and scores in the output
325
+ )
326
+
327
+ transition_scores = model.compute_transition_scores(
328
+ outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
329
+ )
330
+
331
+ # Get the generated token IDs (ignoring the input tokens part)
332
+ generated_ids = outputs.sequences
333
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
334
 
335
+ # Get input length
336
+ input_length = inputs.input_ids.shape[1]
337
+
338
+ # Extract generated tokens (ignoring the input tokens)
339
+ # generated_tokens = outputs.sequences[:, input_length:]
340
+ generated_tokens = outputs.sequences
341
+
342
+ # Calculate output length (number of generated tokens)
343
+ output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
344
+
345
+ # Get length penalty
346
+ length_penalty = model.generation_config.length_penalty
347
+
348
+ # Calculate total score for the generated sentence
349
+ reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
350
+
351
+ # Convert log-probability to probability (0-1 range)
352
+ probabilities = np.exp(reconstructed_scores.cpu().numpy())
353
+
354
+ # Streamlit UI to display the result
355
+
356
+ # st.write(f"**Probability of the Results (0-1):** {probabilities[0]:.4f}")
357
+ st.markdown(f"**🎯 Probability of the Results:** <span style='color:green; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>", unsafe_allow_html=True)
358
+
359
+
360
+
361
  predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=np_image.shape[:2])
362
 
363
  detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=np_image.shape[:2])
 
381
 
382
  # Display the generated text
383
  st.write("**Generated Text:**", generated_text)
384
+
385
+
386
+