"""Streamlit demo: abnormality grounding on chest X-rays with knowledge descriptions.

A Florence-2-based grounding model receives a caption of the form
"<disease> means <definition>." and returns bounding boxes for the phrases it
grounds. The user can run it on a bundled example image or on an uploaded one.
"""
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import numpy as np
import supervision as sv
import albumentations as A
import cv2
from transformers import AutoConfig
import yaml

# Set Streamlit page configuration for a wide layout (must be the first st.* call).
st.set_page_config(layout="wide")

# Custom CSS for better layout and mobile responsiveness
st.markdown("""
<style>
.main {
    max-width: 1200px; /* Max width for content */
    margin: 0 auto;
}
.block-container {
    padding-top: 2rem;
    padding-bottom: 2rem;
    padding-left: 3rem;
    padding-right: 3rem;
}
.title {
    font-size: 2.5rem;
    text-align: center;
    color: #FF6347;
}
.subheader {
    font-size: 1.5rem;
    margin-bottom: 20px;
}
.btn {
    font-size: 1.1rem;
    padding: 10px 20px;
    background-color: #FF6347;
    color: white;
    border-radius: 5px;
    border: none;
    cursor: pointer;
}
.btn:hover {
    background-color: #FF4500;
}
.column-spacing {
    display: flex;
    justify-content: space-between;
}
.col-half {
    width: 48%;
}
.col-full {
    width: 100%;
}
.instructions {
    padding: 20px;
    background-color: #f9f9f9;
    border-radius: 8px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
}
</style>
""", unsafe_allow_html=True)

# Base Florence-2 checkpoint whose config and processor the fine-tuned model reuses.
FLORENCE_BASE = "microsoft/Florence-2-base-ft"
# Florence-2 task token used both in the prompt and for post-processing.
GROUNDING_TASK = "<CAPTION_TO_PHRASE_GROUNDING>"

INSTRUCTIONS_MD = """
- **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
- **Choose an Example**: π Select an example image from the dataset to view its associated diseases.
- **Upload Your Own Image**: π€ Upload an image of your choice to analyze it for diseases.
- **Select Dataset**: π Choose between available datasets (Vindr or PadChest) for disease information.
- **Select Disease**: π¦ Pick the disease to be analyzed from the list of diseases in the selected dataset.
"""

WARNING_MD = """
- **π« Please avoid uploading non-frontal chest X-ray images.** Our model has been specifically trained on **frontal chest X-ray images** only.
- This demo is intended for **π¬ research purposes only** and should **β not be used for medical diagnoses**.
- The modelβs responses may contain **<span style='color:#dc3545; font-weight:bold;'>π€ hallucinations or incorrect information</span>**.
- Always consult a **<span style='color:#dc3545; font-weight:bold;'>π¨ββοΈ medical professional</span>** for accurate diagnosis and advice.
"""


# Load Model and Processor
@st.cache_resource
def load_model():
    """Load the grounding model and its processor once per Streamlit server process.

    Returns:
        (model, processor, device) — the model is already moved to ``device``.
    """
    # Alternative checkpoint: "RioJune/AD-KD-MICCAI25"
    MODEL_NAME = 'Anonymous-AC/AD-KD'
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The fine-tuned checkpoint is loaded with the Florence-2 base config; the
    # vision tower's model_type is set to "davit" explicitly.
    config_model = AutoConfig.from_pretrained(FLORENCE_BASE, trust_remote_code=True)
    config_model.vision_config.model_type = "davit"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, trust_remote_code=True, config=config_model
    ).to(DEVICE)

    processor = AutoProcessor.from_pretrained(FLORENCE_BASE, trust_remote_code=True)
    # Resize model inputs to 512 (presumably the fine-tuning resolution — confirm).
    processor.image_processor.size = 512
    processor.image_processor.crop_size = 512
    return model, processor, DEVICE


model, processor, DEVICE = load_model()


# Load Definitions
@st.cache_resource
def load_definitions():
    """Read the disease-definition and example-prompt YAML files.

    Returns:
        (vindr_definitions, padchest_definitions, prompt_definitions) as parsed
        by ``yaml.safe_load``. The code below uses the definition files as
        ``{disease: definition}`` mappings and the prompt file as
        ``{example_image_path: [disease, ...]}``.
    """
    def _read_yaml(path):
        with open(path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)

    return (
        _read_yaml('configs/vindr_definition.yaml'),
        _read_yaml('configs/padchest_definition.yaml'),
        _read_yaml('examples/prompt.yaml'),
    )


vindr_definitions, padchest_definitions, prompt_definitions = load_definitions()
dataset_options = {"Vindr": vindr_definitions, "PadChest": padchest_definitions}


def load_example_images():
    """Return the example image paths listed in the prompt file."""
    return list(prompt_definitions or {})


example_images = load_example_images()


def apply_transform(image, size_mode=512):
    """Letterbox *image* to a square and return it as a 512x512 numpy array.

    The longest side is scaled to ``size_mode``, the short side is zero-padded
    to ``size_mode``, and the result is finally resized to 512x512.
    """
    pad_resize_transform = A.Compose([
        A.LongestMaxSize(max_size=size_mode, interpolation=cv2.INTER_AREA),
        A.PadIfNeeded(min_height=size_mode, min_width=size_mode,
                      border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
        A.Resize(height=512, width=512, interpolation=cv2.INTER_AREA),
    ])
    transformed = pad_resize_transform(image=np.array(image))
    return transformed["image"]


def lookup_definition(disease, preferred=None):
    """Return the definition of *disease*, or "" if no dataset defines it.

    ``preferred`` (a definitions mapping, e.g. the dataset the user selected)
    is searched first, then VinDr, then PadChest.
    """
    for definitions in (preferred, vindr_definitions, padchest_definitions):
        if definitions and disease in definitions:
            return definitions[disease]
    return ""


def build_prompt(disease, definition):
    """Build the phrase-grounding prompt for *disease* described by *definition*."""
    det_obj = f"{disease} means {definition}."
    # Exact prompt format is kept as-is (including the trailing "..").
    return f"{GROUNDING_TASK}Locate the phrases in the caption: {det_obj}."


def sequence_probability(outputs):
    """Turn beam-search outputs into a 0-1 confidence per returned sequence.

    Rebuilds the length-normalised beam score from the per-step transition
    scores (sum of log-probs / length ** length_penalty) and exponentiates it.
    """
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
    ).cpu().numpy()
    # Count only strictly negative step scores as generated tokens (HF recipe);
    # clamp to 1 so an empty generation cannot divide by zero.
    output_length = np.maximum(np.sum(transition_scores < 0, axis=1), 1)
    length_penalty = model.generation_config.length_penalty
    if length_penalty is None:
        length_penalty = 1.0
    reconstructed_scores = transition_scores.sum(axis=1) / (output_length ** length_penalty)
    return np.exp(reconstructed_scores)


def run_grounding(np_image, prompt):
    """Run phrase grounding on an RGB image array and annotate the result.

    Returns:
        (generated_text, probability in [0, 1], annotated PIL image)
    """
    inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        output_scores=True,            # per-step scores, needed for the confidence
        return_dict_in_generate=True,  # sequences + scores in one output object
    )
    probability = float(sequence_probability(outputs)[0])
    generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=False)[0]

    # Florence-2 post-processing and supervision both take (width, height);
    # numpy image shapes are (height, width, channels).
    height, width = np_image.shape[:2]
    predictions = processor.post_process_generation(
        generated_text, task=GROUNDING_TASK, image_size=(width, height)
    )
    detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=(width, height))

    # Annotate the image with bounding boxes and labels
    bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated = bounding_box_annotator.annotate(np_image.copy(), detection)
    annotated = label_annotator.annotate(annotated, detection)
    return generated_text, probability, Image.fromarray(annotated.astype(np.uint8))


def render_instructions():
    """Show the usage instructions and the research-only warning."""
    st.subheader("βοΈ Instructions to Get Started:")
    st.write(INSTRUCTIONS_MD)
    st.subheader("β οΈ Warning:")
    st.write(WARNING_MD, unsafe_allow_html=True)


def render_results(original, annotated, original_caption, probability, generated_text, prob_color):
    """Show the confidence, the original/annotated images side by side, and the raw output."""
    st.markdown(
        f"**π― Probability of the Results:** <span style='color:{prob_color}; font-size:24px; font-weight:bold;'>{probability * 100:.2f}%</span>",
        unsafe_allow_html=True,
    )
    left, right = st.columns([1, 1])
    with left:
        st.image(original, caption=original_caption, width=400)
    with right:
        st.image(annotated, caption="Inference Results πΌοΈ", width=400)
    st.write("**Generated Text:**", generated_text)


# Streamlit UI with Colorful Title and Emojis
st.markdown(
    "<h1 class='title'>π©Ί Enhancing Abnormality Grounding for Vision Language Models with Knowledge Descriptions π</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p style='text-align: center; font-size: 18px;'>Welcome to a simple demo of our work! π Choose an example or upload your own image to get started! π</p>",
    unsafe_allow_html=True,
)

# ---- Example images -------------------------------------------------------
st.subheader("π Example Images")
selected_example = st.selectbox("Choose an example", example_images)
example_image = Image.open(selected_example).convert("RGB") if selected_example else None
example_diseases = (prompt_definitions or {}).get(selected_example) or []
st.write("**Associated Diseases:**", ", ".join(example_diseases))

col1, col2 = st.columns([1, 2])
with col1:
    if example_image is not None:
        st.image(example_image, caption=f"Original Example Image: {selected_example}", width=400)
with col2:
    render_instructions()

if st.button("Run Inference on Example", key="example"):
    if example_image is None:
        st.error("β Please select an example image first.")
    else:
        # Ground the example's first associated disease.
        example_disease = example_diseases[0] if example_diseases else ""
        example_definition = lookup_definition(example_disease)
        st.write(f"**Definition:** {example_definition}")
        with st.spinner("Processing... β³"):
            text, prob, annotated_image = run_grounding(
                np.array(example_image), build_prompt(example_disease, example_definition)
            )
        render_results(example_image, annotated_image, f"Original Image: {selected_example}",
                       prob, text, "#28a745")

# ---- Upload your own image ------------------------------------------------
st.subheader("π€ Upload Your Own Image")
col1, col2 = st.columns([1, 1])
with col1:
    dataset_choice = st.selectbox("Select Dataset π", options=list(dataset_options.keys()))
    selected_definitions = dataset_options[dataset_choice] or {}
    disease_options = list(selected_definitions.keys())
with col2:
    disease_choice = st.selectbox("Select Disease π¦ ", options=disease_options)

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

uploaded_image = None  # stays None until the user uploads a file
col1, col2 = st.columns([1, 2])
with col1:
    if uploaded_file:
        uploaded_image = apply_transform(Image.open(uploaded_file).convert("RGB"))
        st.image(uploaded_image, caption="Uploaded Image", width=400)
    if not disease_choice:
        disease_choice = example_diseases[0] if example_diseases else ""
    # Definition priority: selected dataset -> other datasets -> manual input.
    definition = lookup_definition(disease_choice, preferred=selected_definitions)
    if not definition:
        definition = st.text_input("Enter Definition Manually π", value="")
with col2:
    render_instructions()

if st.button("Run Inference πββοΈ"):
    if uploaded_image is None:
        st.error("β Please upload an image or select an example.")
    else:
        st.write(f"**Definition:** {definition}")
        with st.spinner("Processing... β³"):
            text, prob, annotated_image = run_grounding(
                np.array(uploaded_image), build_prompt(disease_choice, definition)
            )
        render_results(uploaded_image, annotated_image, "Uploaded Image", prob, text, "green")