mgbam's picture
Update app.py
ed31030 verified
raw
history blame
18.9 kB
# --- Docstring ---
"""
Streamlit application for Medical Image Analysis using Google Gemini Vision
and Retrieval-Augmented Generation (RAG) with Chroma DB, enhanced for
Hugging Face Spaces deployment and improved practices.
Features:
- Image analysis via Google Gemini Pro Vision.
- RAG using Chroma DB with Hugging Face embeddings.
- Caching for performance.
- Basic logging.
- Improved UX and error handling.
- Explicit Disclaimer.
"""
# --- Imports ---
import streamlit as st
import google.generativeai as genai
import chromadb
from chromadb.utils import embedding_functions
from PIL import Image
import io
import time
import logging
from typing import Optional, Dict, List, Any, Tuple
# --- Basic Logging Setup ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Application Configuration ---
# Secrets come from Streamlit's secrets store (on Hugging Face Spaces they are
# configured as Space secrets). GOOGLE_API_KEY is mandatory; HF_TOKEN is
# optional and is None when absent. Any failure is shown in the UI, logged,
# and stops the script run before anything else executes.
_secret_error: Optional[str] = None
try:
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
    HF_TOKEN = st.secrets.get("HF_TOKEN")
except KeyError as missing_key:
    _secret_error = f"❌ Missing Secret: {missing_key}. Please add it to your Hugging Face Space secrets."
except Exception as load_exc:
    _secret_error = f"❌ Error loading secrets: {load_exc}"

if _secret_error is not None:
    st.error(_secret_error)
    logger.error(_secret_error)
    st.stop()
# Gemini Configuration
# NOTE: "gemini-pro-vision" (Gemini 1.0 Pro Vision) has been retired by Google,
# and requests to it fail with a model-not-found error. Current Gemini models
# accept text and images in the same request. Model availability changes over
# time; check the Gemini API model list when updating this value.
VISION_MODEL_NAME = "gemini-2.0-flash"
GENERATION_CONFIG = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 1024,
}
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
Describe the key visual features relevant to a medical context.
Identify potential:
- Diseases or conditions indicated
- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
- Visible cell types
- Relevant biomarkers (if inferable from staining or morphology)
- Anatomical context (if discernible)
Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
Structure the output clearly, perhaps using bullet points for findings.
"""

# Chroma DB Configuration
CHROMA_PATH = "chroma_data_hf"  # On-disk location of the persistent Chroma store
COLLECTION_NAME = "medical_docs_hf"
# Embedding model used by chromadb's HuggingFaceEmbeddingFunction. That wrapper
# sends text to the hosted Hugging Face Inference API, so the model must be
# served there (for feature extraction). Compute happens remotely, not in this Space.
# 'all-mpnet-base-v2' is general purpose. Candidates for better medical retrieval:
# - 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
# - 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'
# - Other models tagged 'medical' or 'biomedical' on the Hugging Face Hub.
# NOTE: vectors already stored in COLLECTION_NAME were produced by the current
# model. After changing EMBEDDING_MODEL_NAME, use a new COLLECTION_NAME (or
# re-index), because embeddings from different models are not comparable.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"  # <-- REPLACE if possible
# Distance metric for the collection's HNSW index; the UI converts distance to
# similarity as 1 - distance, which assumes cosine.
CHROMA_DISTANCE_METRIC = "cosine"
# --- Caching Resource Initialization ---
@st.cache_resource
def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
    """Configure the Gemini SDK and build the vision model.

    Decorated with ``st.cache_resource``, so the returned object (including a
    ``None`` from a failed attempt) is reused until the cache is cleared.

    Returns:
        A ``genai.GenerativeModel`` built from VISION_MODEL_NAME,
        GENERATION_CONFIG and SAFETY_SETTINGS, or ``None`` if setup raised
        (the error is shown with ``st.error`` and logged with traceback).
    """
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        vision_model = genai.GenerativeModel(
            model_name=VISION_MODEL_NAME,
            generation_config=GENERATION_CONFIG,
            safety_settings=SAFETY_SETTINGS,
        )
    except Exception as exc:
        failure = f"❌ Error initializing Gemini Model ({VISION_MODEL_NAME}): {exc}"
        st.error(failure)
        logger.error(failure, exc_info=True)
        return None
    else:
        logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
        return vision_model
@st.cache_resource
def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
    """Create the Chroma embedding function backed by Hugging Face.

    chromadb's HuggingFaceEmbeddingFunction calls the hosted Hugging Face
    Inference API with HF_TOKEN as its API key. Decorated with
    ``st.cache_resource``, so the result (including ``None``) is reused until
    the cache is cleared.

    Returns:
        The embedding function for EMBEDDING_MODEL_NAME, or ``None`` if
        construction raised (the error is shown in the UI and logged).
    """
    if not HF_TOKEN:
        # The hosted Inference API generally rejects unauthenticated requests,
        # so a missing token usually surfaces later as failed embedding calls.
        # Warn early to make that failure easier to diagnose.
        logger.warning(
            "HF_TOKEN is not set; Hugging Face Inference API embedding requests may be rejected."
        )
    try:
        embed_func = embedding_functions.HuggingFaceEmbeddingFunction(
            api_key=HF_TOKEN,
            model_name=EMBEDDING_MODEL_NAME,
        )
    except Exception as e:
        err_msg = f"❌ Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        st.info("ℹ️ Make sure the embedding model name is correct and you have network access. "
                "If using a private model, ensure HF_TOKEN is set in secrets.")
        return None
    # NOTE: construction does not contact the API; model/token problems only
    # show up when the first embedding request is made.
    logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
    return embed_func
@st.cache_resource
def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
    """Open (or create) the persistent Chroma collection used for retrieval.

    The leading underscore on ``_embedding_func`` tells ``st.cache_resource``
    not to hash that argument, so the first result is reused on later calls.

    Args:
        _embedding_func: Embedding function attached to the collection; Chroma
            uses it to embed documents on ``add`` and query texts on ``query``.

    Returns:
        The collection COLLECTION_NAME stored under CHROMA_PATH (created with
        the CHROMA_DISTANCE_METRIC HNSW space if new), or ``None`` when no
        embedding function was supplied or the client/collection setup raised.
    """
    if _embedding_func:
        try:
            client = chromadb.PersistentClient(path=CHROMA_PATH)
            kb_collection = client.get_or_create_collection(
                name=COLLECTION_NAME,
                embedding_function=_embedding_func,
                metadata={"hnsw:space": CHROMA_DISTANCE_METRIC},
            )
        except Exception as exc:
            failure = f"❌ Error initializing Chroma DB at '{CHROMA_PATH}': {exc}"
            st.error(failure)
            logger.error(failure, exc_info=True)
            st.info(f"ℹ️ Ensure the path '{CHROMA_PATH}' is writable.")
            return None
        logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
        return kb_collection

    st.error("❌ Cannot initialize Chroma DB without a valid embedding function.")
    return None
# --- Core Logic Functions (with Caching for Data Operations) ---
class _GeminiNoContentError(Exception):
    """Gemini returned a response without content (safety block or empty)."""

    def __init__(self, message: str, blocked: bool) -> None:
        super().__init__(message)
        # True when the prompt feedback carried a block reason.
        self.blocked = blocked


@st.cache_data(show_spinner=False)  # Progress is shown manually in the UI
def _run_gemini_analysis(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> str:
    """Send *image_bytes* to Gemini and return the analysis text.

    Cached by ``st.cache_data`` on ``image_bytes``. The leading underscore
    keeps ``_gemini_model`` out of the cache key. Every failure raises, and
    Streamlit does not cache calls that raise, so only successful analyses are
    cached. A retry after a transient error therefore reaches the API again.

    Raises:
        _GeminiNoContentError: the response had no parts.
        Exception: anything raised by PIL or the Gemini SDK propagates.
    """
    img = Image.open(io.BytesIO(image_bytes))
    response = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, img])
    if not response.parts:
        if response.prompt_feedback and response.prompt_feedback.block_reason:
            reason = response.prompt_feedback.block_reason
            raise _GeminiNoContentError(f"Analysis blocked by safety settings: {reason}", blocked=True)
        raise _GeminiNoContentError(
            "Error: Gemini analysis returned no content (empty or invalid response).", blocked=False
        )
    logger.info("Gemini analysis successful.")
    return response.text


def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
    """Analyze image bytes with Gemini.

    Successful analyses are cached per image (see ``_run_gemini_analysis``).
    Errors and blocks are not cached, so the same image can be retried.

    Args:
        _gemini_model: Initialized Gemini model, or ``None``/falsy if
            initialization failed.
        image_bytes: Raw bytes of an image file that PIL can open.

    Returns:
        ``(analysis_text, is_error)``. On success ``is_error`` is False and the
        text is Gemini's answer. On failure ``is_error`` is True and the text
        is a human-readable message of the form "<context>: <details>".
    """
    if not _gemini_model:
        return "Error: Gemini model not initialized.", True
    try:
        return _run_gemini_analysis(_gemini_model, image_bytes), False
    except _GeminiNoContentError as e:
        msg = str(e)
        # A safety block is an expected outcome (warning); an empty response is not.
        if e.blocked:
            logger.warning(msg)
        else:
            logger.error(msg)
        return msg, True
    except genai.types.BlockedPromptException as e:
        msg = f"Analysis blocked (prompt issue): {e}"
        logger.warning(msg)
        return msg, True
    except Exception as e:
        msg = f"Error during Gemini analysis: {e}"
        logger.error(msg, exc_info=True)
        return msg, True
@st.cache_data(show_spinner=False)
def _cached_chroma_query(
    _collection: chromadb.Collection,
    query_text: str,
    n_results: int,
    collection_size: int,
) -> Dict[str, List[Any]]:
    """Run one Chroma similarity query; the result is cached by Streamlit.

    The cache key is ``(query_text, n_results, collection_size)``. The leading
    underscore excludes ``_collection`` from the key. ``collection_size`` is
    not used in the body; it is part of the key so that adding documents to
    the collection invalidates previously cached (possibly empty) results.
    Exceptions propagate, and Streamlit does not cache calls that raise.
    """
    # Placeholder for potential query refinement (e.g. a refine_query_for_chroma
    # helper); the analysis text is used directly for now.
    return _collection.query(
        query_texts=[query_text],
        n_results=n_results,
        include=['documents', 'metadatas', 'distances'],
    )


def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
    """Query Chroma for documents similar to *query_text*.

    Args:
        _collection: Chroma collection to search, or ``None``/falsy.
        query_text: Text to embed and search with.
        n_results: Maximum number of hits; clamped to the collection size.

    Returns:
        Chroma's query result dict. There is one list per query text, so this
        query's hits are at index 0 of "documents", "metadatas" and
        "distances". An empty collection yields empty hit lists. Returns
        ``None`` when the collection is missing, the text is empty, or the
        query fails; failures are also shown via ``st.error`` and logged.
    """
    if not _collection:
        return None
    if not query_text:
        logger.warning("Attempted to query Chroma with empty text.")
        return None
    try:
        collection_size = _collection.count()
        if collection_size == 0:
            logger.info("Chroma collection is empty; skipping query.")
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
        results = _cached_chroma_query(
            _collection, query_text, min(n_results, collection_size), collection_size
        )
    except Exception as e:
        err_msg = f"Error querying Chroma DB: {e}"
        st.error(err_msg)  # Show error in UI as well
        logger.error(err_msg, exc_info=True)
        return None
    logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
    return results
def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
    """Add the built-in example medical snippets to Chroma, skipping ones already stored.

    Document IDs are derived from a hash of each snippet's text, so the same
    snippet always gets the same ID. The existence check can therefore look up
    IDs directly, and repeated button clicks do not create duplicates.

    Embeddings are produced by the embedding function configured on
    *collection* when ``collection.add`` runs. *embedding_func* is only checked
    for availability. Progress and the outcome are reported through an
    ``st.status`` element; errors are shown there and logged, not raised.
    """
    import hashlib  # Local import: only needed for the content-derived IDs below.

    if not collection or not embedding_func:
        st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function not available.")
        return
    status = st.status("Adding dummy data to Chroma DB...", expanded=False)
    try:
        # --- Dummy Data Definition ---
        docs = [
            "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
            "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
            "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
            "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
            "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted."
        ]
        metadatas = [
            {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
            {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
            {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer", "IMAGE_ID": "diagram_egfr_pathway.svg"},
            {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
            {"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"}
        ]
        ids = [f"doc_hf_{hashlib.sha256(doc.encode('utf-8')).hexdigest()[:16]}" for doc in docs]

        # Chroma `where` filters match metadata fields, not document text, so
        # existing snippets are detected by looking up their deterministic IDs.
        status.update(label="Checking for existing dummy documents...")
        existing_ids = set(collection.get(ids=ids, include=[]).get("ids") or [])
        missing = [idx for idx, doc_id in enumerate(ids) if doc_id not in existing_ids]

        if missing:
            status.update(label=f"Generating embeddings for {len(missing)} documents (may take time)...")
            # Embeddings are generated implicitly by ChromaDB during .add()
            # because an embedding_function is configured on the collection.
            collection.add(
                documents=[docs[idx] for idx in missing],
                metadatas=[metadatas[idx] for idx in missing],
                ids=[ids[idx] for idx in missing],
            )
            status.update(label=f"✅ Added {len(missing)} dummy documents.", state="complete")
            logger.info(f"Added {len(missing)} dummy documents to collection '{COLLECTION_NAME}'.")
        else:
            status.update(label="⚠️ Dummy data already exists. No new data added.", state="complete")
            logger.warning("Dummy data already exists in the collection (all dummy document IDs present).")
    except Exception as e:
        err_msg = f"Error adding dummy data to Chroma: {e}"
        status.update(label=f"❌ Error: {err_msg}", state="error")
        logger.error(err_msg, exc_info=True)
# --- Page Configuration ---
# Older Streamlit releases require st.set_page_config to be the first Streamlit
# command. The cached initializers below can emit elements (the default cache
# spinner, and st.error/st.info on failure), so the page is configured first.
st.set_page_config(layout="wide", page_title="Medical Image Analysis & RAG (HF)")

# --- Initialize Resources ---
# These use @st.cache_resource: the objects are created once and reused across
# reruns and sessions until the cache is cleared.
gemini_model = initialize_gemini_model()
embedding_func = initialize_embedding_function()
collection = initialize_chroma_collection(embedding_func)  # Chroma embeds with this function

# --- Streamlit UI ---
st.title("⚕️ Medical Image Analysis & RAG (Hugging Face Enhanced)")

# --- DISCLAIMER ---
st.warning("""
**⚠️ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis, treatment, or decision-making.
AI analysis can be imperfect. Always consult with qualified healthcare professionals for any medical concerns.
Do **NOT** upload identifiable patient data (PHI).
""")
st.markdown("""
Upload a medical image. Gemini Vision will analyze it, and related information
will be retrieved from a Chroma DB knowledge base using Hugging Face embeddings.
""")

# Sidebar
with st.sidebar:
    st.header("⚙️ Controls")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "tiff", "webp"],
        help="Upload a medical image file (e.g., pathology, diagram)."
    )
    st.divider()
    if st.button("➕ Add/Verify Dummy KB Data", help="Adds example text data to Chroma DB if it doesn't exist."):
        if collection and embedding_func:
            add_dummy_data_to_chroma(collection, embedding_func)
        else:
            st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
    st.divider()
    st.info(f"""
**Setup Info:**
- Gemini Model: `{VISION_MODEL_NAME}`
- Embedding Model: `{EMBEDDING_MODEL_NAME}`
- Chroma Collection: `{COLLECTION_NAME}` (at `{CHROMA_PATH}`)
- Distance Metric: `{CHROMA_DISTANCE_METRIC}`
""")
    # Report only whether the key is present. The page may be public, so no part
    # of the key is shown: not its length, not its last characters, and not the
    # whole key, which the old mask displayed when the key was under 5 chars.
    st.caption(f"Using Google API Key: {'Provided' if GOOGLE_API_KEY else 'Not Set'}")
    st.caption(f"Using HF Token: {'Provided' if HF_TOKEN else 'Not Provided'}")

# Main Display Area
col1, col2 = st.columns(2)

with col1:
    st.subheader("🖼️ Uploaded Image")
    if uploaded_file is not None:
        image_bytes = uploaded_file.getvalue()
        # NOTE(review): `use_column_width` is deprecated in recent Streamlit releases
        # in favour of `use_container_width`; switch once the pinned version supports it.
        st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
    else:
        st.info("Upload an image using the sidebar to begin.")

with col2:
    st.subheader("🔬 Analysis & Retrieval")
    if uploaded_file is not None and gemini_model and collection:
        # 1. Analyze Image (the analysis function caches successful results per image)
        with st.status("🧠 Analyzing image with Gemini Vision...", expanded=True) as status_gemini:
            analysis_text, analysis_error = analyze_image_with_gemini(gemini_model, image_bytes)
            if analysis_error:
                # Messages look like "<context>: <details>"; keep everything after
                # the first colon (details may themselves contain colons).
                detail = analysis_text.partition(":")[2].strip() or "See details"
                status_gemini.update(label=f"⚠️ Analysis Failed/Blocked: {detail}", state="error")
                st.error(f"**Analysis Output:** {analysis_text}")  # Show error/block message
            else:
                status_gemini.update(label="✅ Analysis Complete", state="complete")

        if not analysis_error:
            st.markdown("**Gemini Vision Analysis:**")
            st.markdown(analysis_text)

        # 2. Query Chroma if Analysis Succeeded
        if not analysis_error and analysis_text:
            st.markdown("---")
            st.subheader("📚 Related Information (RAG)")
            with st.status("🔍 Searching knowledge base (Chroma DB)...", expanded=True) as status_chroma:
                chroma_results = query_chroma(collection, analysis_text, n_results=3)
                has_hits = bool(
                    chroma_results
                    and chroma_results.get('documents')
                    and chroma_results['documents'][0]
                )
                if has_hits:
                    num_results = len(chroma_results['documents'][0])
                    status_chroma.update(label=f"✅ Found {num_results} related entries.", state="complete")
                elif chroma_results is not None:  # Query ran, no results
                    # st.status only accepts "running", "complete" and "error".
                    status_chroma.update(label="⚠️ No relevant information found.", state="complete")
                else:  # Query failed (already logged and shown via st.error)
                    status_chroma.update(label="❌ Failed to retrieve results.", state="error")

            # Results are rendered outside the status container: st.status is an
            # expander-type container, and Streamlit does not allow nesting expanders.
            if has_hits:
                hits = zip(
                    chroma_results['documents'][0],
                    chroma_results['metadatas'][0],
                    chroma_results['distances'][0],
                )
                for rank, (doc, meta, dist) in enumerate(hits, start=1):
                    meta = meta or {}  # Documents may have been stored without metadata
                    similarity = 1.0 - dist  # Valid because the collection uses cosine distance
                    expander_title = f"Result {rank} (Similarity: {similarity:.4f}) | Source: {meta.get('source', 'N/A')}"
                    with st.expander(expander_title):
                        st.markdown("**Retrieved Text:**")
                        st.markdown(f"> {doc}")
                        st.markdown("**Metadata:**")
                        for key, value in meta.items():
                            st.markdown(f"- **{key.replace('_', ' ').title()}:** `{value}`")
                        # Highlight linked image ID
                        if meta.get("IMAGE_ID"):
                            st.info(f"ℹ️ Associated visual asset ID: `{meta['IMAGE_ID']}`")
    elif uploaded_file is None:
        st.info("Analysis results will appear here once an image is uploaded.")
    else:
        st.error("❌ Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar/logs).")

st.markdown("---")
st.markdown("<div style='text-align: center; font-size: small;'>Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit</div>", unsafe_allow_html=True)