"""Streamlit app: chest X-ray analysis with Maira-2 via a Transformers pipeline."""

import io

import streamlit as st
import torch  # Used to detect CUDA availability when choosing the pipeline device.
from PIL import Image
from transformers import pipeline

# --- Configuration ---
# Hugging Face Hub identifier of the model to load.
MODEL_NAME = "microsoft/maira-2"

# Prompt sent to the model together with the uploaded image.
QUESTION = (
    "Analyze this chest X-ray image and provide detailed findings. "
    "Include any abnormalities, their locations, and potential diagnoses. "
    "Be as specific as possible."
)


# --- Model Loading (using pipeline) ---
@st.cache_resource  # Cache the pipeline across reruns and sessions.
def _load_pipeline_cached():
    """Build the VQA pipeline, raising on failure.

    Failures must propagate out of this function: st.cache_resource caches
    return values (including a ``None`` returned from an error handler) but
    does not cache raised exceptions, so raising lets a later rerun retry.
    """
    # Device index 0 selects the first CUDA GPU; -1 selects the CPU.
    device = 0 if torch.cuda.is_available() else -1
    # NOTE(review): microsoft/maira-2 is published with custom modelling code
    # (trust_remote_code) and may not be supported by the generic
    # "visual-question-answering" pipeline — confirm against the model card.
    return pipeline("visual-question-answering", model=MODEL_NAME, device=device)


def load_pipeline():
    """Load the VQA pipeline, reporting any failure in the UI.

    Returns:
        The cached pipeline object, or ``None`` if loading failed.
    """
    try:
        return _load_pipeline_cached()
    except Exception as e:  # UI boundary: surface any loading failure to the user.
        st.error(f"Error loading pipeline: {e}")
        return None


# --- Image Preprocessing ---
def _to_rgb(image):
    """Return *image* in RGB mode, converting from any other PIL mode."""
    # Uploaded X-rays are often grayscale ("L"), palette ("P"), "LA" or 16-bit
    # ("I;16"), not just "RGBA", so convert every non-RGB mode.
    # NOTE(review): 16-bit images may need windowing/rescaling before this
    # conversion to avoid clipping — confirm with real inputs.
    return image if image.mode == "RGB" else image.convert("RGB")


def prepare_image(image):
    """Encode a PIL image as RGB JPEG bytes.

    Any non-RGB image is converted first, because PIL's JPEG writer rejects
    modes such as "RGBA", "LA", "P" and "I;16". Kept for callers that need
    encoded bytes; the pipeline itself is given a PIL image (see ``main``).

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    buffer = io.BytesIO()
    _to_rgb(image).save(buffer, format="JPEG")
    return buffer.getvalue()


def _render_results(results):
    """Display the highest-scoring answer from the pipeline output."""
    # Normalise a single-answer dict to a one-element list.
    if isinstance(results, dict):
        results = [results]
    if not isinstance(results, list):
        st.warning("Unexpected result format.")
        return
    if not results:
        st.warning("The model returned no answer.")
        return
    best_answer = max(results, key=lambda r: r.get("score", 0))
    if "answer" in best_answer:
        st.subheader("Findings:")
        st.write(best_answer["answer"])
    else:
        st.warning("Could not find 'answer' in results.")


# --- Streamlit App ---
def main():
    """Render the upload form and show the model's findings for the image."""
    st.title("Chest X-ray Analysis with Maira-2 (Transformers Pipeline)")
    st.write(
        "Upload a chest X-ray image. "
        "This app uses the Maira-2 model via the Transformers library."
    )

    vqa_pipeline = load_pipeline()
    if vqa_pipeline is None:
        st.warning("Pipeline not loaded. Predictions will not be available.")
        return

    uploaded_file = st.file_uploader(
        "Choose a chest X-ray image (JPG, PNG)", type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        st.write("Please upload an image.")
    else:
        try:
            image = Image.open(uploaded_file)
            # Image.open is lazy; force decoding so corrupt files fail here.
            image.load()
            rgb_image = _to_rgb(image)
        except (OSError, ValueError) as e:  # OSError covers UnidentifiedImageError.
            st.error(f"Could not read the uploaded file as an image: {e}")
        else:
            # NOTE(review): use_column_width is deprecated in recent Streamlit
            # releases in favour of use_container_width — switch once the
            # minimum supported Streamlit version allows it.
            st.image(rgb_image, caption="Uploaded Image", use_column_width=True)
            with st.spinner("Analyzing image with Maira-2..."):
                try:
                    # The VQA pipeline accepts a PIL image or a str
                    # (path/URL/base64), not raw bytes.
                    results = vqa_pipeline(image=rgb_image, question=QUESTION)
                    _render_results(results)
                except Exception as e:  # UI boundary: report any inference error.
                    st.error(f"An error occurred during analysis: {e}")

    st.write("---")
    st.write("Disclaimer: For informational purposes only. Not medical advice.")


if __name__ == "__main__":
    main()