# --- Docstring ---
"""
Streamlit application for Medical Image Analysis using Google Gemini Vision and
Retrieval-Augmented Generation (RAG) with Chroma DB, enhanced for Hugging Face Spaces
deployment and improved practices.

Features:
- Image analysis via Google Gemini Pro Vision.
- RAG using Chroma DB with Hugging Face embeddings.
- Caching for performance.
- Basic logging.
- Improved UX and error handling.
- Explicit Disclaimer.
"""

# --- Imports ---
import streamlit as st
import google.generativeai as genai
import chromadb
from chromadb.utils import embedding_functions
from PIL import Image
import io
import time
import logging
from typing import Optional, Dict, List, Any, Tuple

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Application Configuration ---

# Secrets Management (Prioritize Hugging Face Secrets)
try:
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
    # HF_TOKEN is optional for many public models, but required for gated/private ones
    HF_TOKEN = st.secrets.get("HF_TOKEN")  # Use .get() for optional token
except KeyError as e:
    err_msg = f"❌ Missing Secret: {e}. Please add it to your Hugging Face Space secrets."
    st.error(err_msg)
    logger.error(err_msg)
    st.stop()
except Exception as e:
    err_msg = f"❌ Error loading secrets: {e}"
    st.error(err_msg)
    logger.error(err_msg)
    st.stop()

# Gemini Configuration
VISION_MODEL_NAME = "gemini-pro-vision"
GENERATION_CONFIG = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 1024,
}
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
Describe the key visual features relevant to a medical context. Identify potential:
- Diseases or conditions indicated
- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
- Visible cell types
- Relevant biomarkers (if inferable from staining or morphology)
- Anatomical context (if discernible)

Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
Structure the output clearly, perhaps using bullet points for findings.
"""

# Chroma DB Configuration
CHROMA_PATH = "chroma_data_hf"  # Use a distinct path if needed
COLLECTION_NAME = "medical_docs_hf"
# IMPORTANT: Choose an appropriate HF embedding model. 'all-mpnet-base-v2' is general purpose.
# For better medical results, consider models like:
# - 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext' (might need more RAM/compute)
# - 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'
# - Other models tagged 'medical' or 'biomedical' on the Hugging Face Hub.
# Ensure the chosen model is compatible with chromadb's HuggingFaceEmbeddingFunction.
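# For example, a biomedical model could be swapped in here (a hypothetical choice; verify
# that it is actually served by the HF Inference API used by HuggingFaceEmbeddingFunction):
#   EMBEDDING_MODEL_NAME = "dmis-lab/sapbert-from-pubmedbert-sentencetransformer"
# Alternatively, embeddings can be computed locally on the Space with chromadb's
# SentenceTransformerEmbeddingFunction (more RAM/compute, no HF_TOKEN required), e.g.:
#   embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL_NAME)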
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"  # <-- REPLACE if possible
CHROMA_DISTANCE_METRIC = "cosine"

# --- Caching Resource Initialization ---

@st.cache_resource
def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
    """Initializes and returns the Gemini Generative Model."""
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        model = genai.GenerativeModel(
            model_name=VISION_MODEL_NAME,
            generation_config=GENERATION_CONFIG,
            safety_settings=SAFETY_SETTINGS
        )
        logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
        return model
    except Exception as e:
        err_msg = f"❌ Error initializing Gemini Model ({VISION_MODEL_NAME}): {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        return None


@st.cache_resource
def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
    """Initializes and returns the Hugging Face Embedding Function."""
    try:
        # Pass HF_TOKEN if it exists (required for private/gated models)
        embed_func = embedding_functions.HuggingFaceEmbeddingFunction(
            api_key=HF_TOKEN,
            model_name=EMBEDDING_MODEL_NAME
        )
        logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
        return embed_func
    except Exception as e:
        err_msg = f"❌ Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        st.info("ℹ️ Make sure the embedding model name is correct and you have network access. "
                "If using a private model, ensure HF_TOKEN is set in secrets.")
        return None


@st.cache_resource
def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
    """Initializes the Chroma DB client and returns the collection."""
    if not _embedding_func:
        st.error("❌ Cannot initialize Chroma DB without a valid embedding function.")
        return None
    try:
        chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
        collection = chroma_client.get_or_create_collection(
            name=COLLECTION_NAME,
            embedding_function=_embedding_func,  # Pass the initialized function
            metadata={"hnsw:space": CHROMA_DISTANCE_METRIC}
        )
        logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
        return collection
    except Exception as e:
        err_msg = f"❌ Error initializing Chroma DB at '{CHROMA_PATH}': {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        st.info(f"ℹ️ Ensure the path '{CHROMA_PATH}' is writable.")
        return None


# --- Core Logic Functions (with Caching for Data Operations) ---

@st.cache_data(show_spinner=False)  # Show spinner manually in UI
def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
    """
    Analyzes image bytes with Gemini, returns (analysis_text, is_error).
    Uses Streamlit's caching based on image_bytes.
    """
    if not _gemini_model:
        return "Error: Gemini model not initialized.", True
    try:
        img = Image.open(io.BytesIO(image_bytes))
        response = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, img])
        if not response.parts:
            if response.prompt_feedback and response.prompt_feedback.block_reason:
                reason = response.prompt_feedback.block_reason
                msg = f"Analysis blocked by safety settings: {reason}"
                logger.warning(msg)
                return msg, True  # Indicate block/error state
            else:
                msg = "Error: Gemini analysis returned no content (empty or invalid response)."
                logger.error(msg)
                return msg, True
        logger.info("Gemini analysis successful.")
        return response.text, False  # Indicate success
    except genai.types.BlockedPromptException as e:
        msg = f"Analysis blocked (prompt issue): {e}"
        logger.warning(msg)
        return msg, True
    except Exception as e:
        msg = f"Error during Gemini analysis: {e}"
        logger.error(msg, exc_info=True)
        return msg, True


@st.cache_data(show_spinner=False)
def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
    """Queries Chroma DB, returns results dict or None on error."""
    if not _collection:
        return None
    if not query_text:
        logger.warning("Attempted to query Chroma with empty text.")
        return None
    try:
        # Placeholder for potential query refinement (a hedged sketch of such a helper
        # appears at the bottom of this file):
        # refined_query = refine_query_for_chroma(query_text)  # Implement this if needed
        refined_query = query_text  # Using direct analysis text for now
        results = _collection.query(
            query_texts=[refined_query],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances']
        )
        logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
        return results
    except Exception as e:
        err_msg = f"Error querying Chroma DB: {e}"
        st.error(err_msg)  # Show error in UI as well
        logger.error(err_msg, exc_info=True)
        return None


def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
    """Adds example medical text snippets to Chroma using the provided embedding function."""
    if not collection or not embedding_func:
        st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function not available.")
        return
    status = st.status("Adding dummy data to Chroma DB...", expanded=False)
    try:
        # --- Dummy Data Definition ---
        # (Same data as before, but ensure metadata is useful)
        docs = [
            "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
            "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
            "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
            "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
            "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted."
        ]
        metadatas = [
            {"source": "Example Paper 1", "topic": "Lung Cancer Pathology",
             "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1",
             "IMAGE_ID": "fig_1a_adeno_lung.png"},
            {"source": "Path Report 789", "topic": "Brain Tumor Pathology",
             "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67",
             "IMAGE_ID": "slide_34b_gbm.tiff"},
            {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways",
             "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer",
             "IMAGE_ID": "diagram_egfr_pathway.svg"},
            {"source": "Path Report 101", "topic": "Gastrointestinal Pathology",
             "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia",
             "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
            {"source": "Case Study CJD", "topic": "Neuropathology",
             "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex",
             "IMAGE_ID": "slide_cjd_sample_02.jpg"},
        ]
        ids = [f"doc_hf_{int(time.time())}_{i}" for i in range(len(docs))]

        # Check for existing documents (simple check based on text).
        # Chroma filters document text via `where_document` (its `where` clause is for
        # metadata only), so check for the text of the first dummy document.
        status.update(label="Checking for existing dummy documents...")
        existing_docs = collection.get(where_document={"$contains": docs[0]})
        if not existing_docs or not existing_docs.get('ids'):
            status.update(label=f"Generating embeddings for {len(docs)} documents (may take time)...")
            # Embeddings are generated implicitly by Chroma DB during .add()
            # when an embedding_function is configured for the collection.
            collection.add(
                documents=docs,
                metadatas=metadatas,
                ids=ids
            )
            status.update(label=f"✅ Added {len(docs)} dummy documents.", state="complete")
            logger.info(f"Added {len(docs)} dummy documents to collection '{COLLECTION_NAME}'.")
        else:
            status.update(label="⚠️ Dummy data already exists. No new data added.", state="complete")
            logger.warning("Dummy data seems to already exist in the collection based on text match.")
    except Exception as e:
        err_msg = f"Error adding dummy data to Chroma: {e}"
        status.update(label=f"❌ Error: {err_msg}", state="error")
        logger.error(err_msg, exc_info=True)


# --- Initialize Resources ---
# These calls use @st.cache_resource, so they run only once per session/resource change.
gemini_model = initialize_gemini_model()
embedding_func = initialize_embedding_function()
collection = initialize_chroma_collection(embedding_func)  # Pass embedding func to chroma init

# --- Streamlit UI ---
st.set_page_config(layout="wide", page_title="Medical Image Analysis & RAG (HF)")

st.title("⚕️ Medical Image Analysis & RAG (Hugging Face Enhanced)")

# --- DISCLAIMER ---
st.warning("""
**⚠️ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis,
treatment, or decision-making. AI analysis can be imperfect. Always consult with qualified
healthcare professionals for any medical concerns. Do **NOT** upload identifiable patient data (PHI).
""")

st.markdown("""
Upload a medical image. Gemini Vision will analyze it, and related information will be
retrieved from a Chroma DB knowledge base using Hugging Face embeddings.
""")

# Sidebar
with st.sidebar:
    st.header("⚙️ Controls")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "tiff", "webp"],
        help="Upload a medical image file (e.g., pathology, diagram)."
    )
    st.divider()
    if st.button("➕ Add/Verify Dummy KB Data",
                 help="Adds example text data to Chroma DB if it doesn't exist."):
        if collection and embedding_func:
            add_dummy_data_to_chroma(collection, embedding_func)
        else:
            st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
    st.divider()
    st.info(f"""
**Setup Info:**
- Gemini Model: `{VISION_MODEL_NAME}`
- Embedding Model: `{EMBEDDING_MODEL_NAME}`
- Chroma Collection: `{COLLECTION_NAME}` (at `{CHROMA_PATH}`)
- Distance Metric: `{CHROMA_DISTANCE_METRIC}`
""")
    st.caption(f"Using Google API Key: {'*' * (len(GOOGLE_API_KEY) - 4)}{GOOGLE_API_KEY[-4:]}" if GOOGLE_API_KEY else "Not Set")
    st.caption(f"Using HF Token: {'Provided' if HF_TOKEN else 'Not Provided'}")

# Main Display Area
col1, col2 = st.columns(2)

with col1:
    st.subheader("🖼️ Uploaded Image")
    if uploaded_file is not None:
        image_bytes = uploaded_file.getvalue()
        st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
    else:
        st.info("Upload an image using the sidebar to begin.")

with col2:
    st.subheader("🔬 Analysis & Retrieval")
    if uploaded_file is not None and gemini_model and collection:
        # 1. Analyze Image
        analysis_text = ""
        analysis_error = False
        with st.status("🧠 Analyzing image with Gemini Vision...", expanded=True) as status_gemini:
            # The actual analysis function is cached via @st.cache_data
            analysis_text, analysis_error = analyze_image_with_gemini(gemini_model, image_bytes)
            if analysis_error:
                status_gemini.update(
                    label=f"⚠️ Analysis Failed/Blocked: {analysis_text.split(':')[1].strip() if ':' in analysis_text else 'See details'}",
                    state="error"
                )
                st.error(f"**Analysis Output:** {analysis_text}")  # Show error/block message
            else:
                status_gemini.update(label="✅ Analysis Complete", state="complete")
                st.markdown("**Gemini Vision Analysis:**")
                st.markdown(analysis_text)
        # 2. Query Chroma if Analysis Succeeded
        if not analysis_error and analysis_text:
            st.markdown("---")
            st.subheader("📚 Related Information (RAG)")
            with st.status("🔍 Searching knowledge base (Chroma DB)...", expanded=True) as status_chroma:
                # The actual query function is cached via @st.cache_data
                chroma_results = query_chroma(collection, analysis_text, n_results=3)
                if chroma_results and chroma_results.get('documents') and chroma_results['documents'][0]:
                    num_results = len(chroma_results['documents'][0])
                    status_chroma.update(label=f"✅ Found {num_results} related entries.", state="complete")
                    for i in range(num_results):
                        doc = chroma_results['documents'][0][i]
                        meta = chroma_results['metadatas'][0][i]
                        dist = chroma_results['distances'][0][i]
                        similarity = 1.0 - dist  # For cosine distance
                        expander_title = f"Result {i+1} (Similarity: {similarity:.4f}) | Source: {meta.get('source', 'N/A')}"
                        with st.expander(expander_title):
                            st.markdown("**Retrieved Text:**")
                            st.markdown(f"> {doc}")
                            st.markdown("**Metadata:**")
                            # Display metadata keys/values more nicely
                            for key, value in meta.items():
                                st.markdown(f"- **{key.replace('_', ' ').title()}:** `{value}`")
                            # Highlight linked image ID
                            if meta.get("IMAGE_ID"):
                                st.info(f"ℹ️ Associated visual asset ID: `{meta['IMAGE_ID']}`")
                elif chroma_results is not None:
                    # Query ran, no results. st.status has no "warning" state, so mark it
                    # complete and keep the ⚠️ label.
                    status_chroma.update(label="⚠️ No relevant information found.", state="complete")
                else:
                    # Error occurred during query (already logged and shown via st.error)
                    status_chroma.update(label="❌ Failed to retrieve results.", state="error")
    elif not uploaded_file:
        st.info("Analysis results will appear here once an image is uploaded.")
    else:
        st.error("❌ Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar/logs).")

st.markdown("---")
st.markdown(
    "Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit",
    unsafe_allow_html=True,
)