# Hugging Face file-view header (extraction residue): mgbam · "Update app.py" · commit 0d23f5f (verified) · raw / history / blame · 15.4 kB
# --- Docstring ---
"""
Streamlit application for Medical Image Analysis using Google Gemini Vision
and Retrieval-Augmented Generation (RAG) with Chroma DB.
Allows users to upload a medical image (pathology slide, diagram, etc.).
1. The image is analyzed by Google's Gemini Pro Vision model to generate a
textual description of key features.
2. This description is then used as a query to a Chroma vector database
(populated with example medical text snippets) to retrieve relevant
information from a simulated knowledge base.
"""
# --- Imports ---
import streamlit as st
import google.generativeai as genai
import chromadb
from chromadb.utils import embedding_functions
from PIL import Image
import io
import time # Used for generating unique IDs for Chroma DB demo data
from typing import Optional, Dict, List, Any # For type hinting
# --- Configuration ---
# Pull the Google API key out of Streamlit secrets and register it with the
# google.generativeai SDK. On any failure an error is shown in the UI and the
# current script run is stopped, so nothing below executes without a key.
try:
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
    genai.configure(api_key=GOOGLE_API_KEY)
except KeyError:
    # The secrets store has no "GOOGLE_API_KEY" entry.
    st.error("❌ GOOGLE_API_KEY not found in Streamlit secrets! Please add it.")
    st.stop()
except Exception as config_error:
    # Anything else raised while reading the secret or configuring the SDK.
    st.error(f"❌ Error configuring Google AI SDK: {config_error}")
    st.stop()
# --- Gemini Model Setup ---
# Model ID handed to genai.GenerativeModel; it must be a vision-capable model
# because the app sends an image alongside the prompt.
# NOTE(review): "gemini-pro-vision" may have been retired by Google — confirm
# the ID is still served before deploying.
VISION_MODEL_NAME = "gemini-pro-vision"

# Generation parameters. The low temperature favors deterministic, factual
# descriptions over creative ones.
GENERATION_CONFIG = dict(
    temperature=0.2,
    top_p=0.95,
    top_k=40,
    max_output_tokens=1024,
)

# Safety settings: the same blocking threshold is applied to every harm
# category. Adjust the threshold if medical imagery is being blocked.
SAFETY_SETTINGS = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Build the model object once at import time; report the outcome in the UI and
# stop the script run if construction fails.
try:
    gemini_model = genai.GenerativeModel(
        model_name=VISION_MODEL_NAME,
        generation_config=GENERATION_CONFIG,
        safety_settings=SAFETY_SETTINGS,
    )
    st.success(f"βœ… Initialized Gemini Model: {VISION_MODEL_NAME}")
except Exception as init_error:
    st.error(f"❌ Error initializing Gemini Model ({VISION_MODEL_NAME}): {init_error}")
    st.stop()
# --- Chroma DB Setup ---
# The vector store is persisted on local disk inside the deployment
# environment (e.g., an HF Space).
# NOTE: the data disappears if that storage is wiped or the environment is
# reset. For production, consider a managed Chroma instance or another database.
CHROMA_PATH = "chroma_data"
COLLECTION_NAME = "medical_docs"

# Embedding function shared by ingestion and querying (Chroma's default).
# IMPORTANT: queries MUST use the same embedding model that was used when the
# data was added to the collection. A medical-domain embedding model may give
# better relevance if one is available.
embedding_func = embedding_functions.DefaultEmbeddingFunction()

try:
    # Open (or create on first run) the on-disk database.
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Fetch the collection, creating it with cosine distance if it is missing;
    # cosine is a common choice for text-similarity search.
    collection = chroma_client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_func,
        metadata={"hnsw:space": "cosine"},
    )
    st.success(f"βœ… Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}'.")
except Exception as chroma_error:
    st.error(f"❌ Error initializing Chroma DB at '{CHROMA_PATH}': {chroma_error}")
    st.info("ℹ️ If this is the first run, the 'chroma_data' directory will be created.")
    st.stop()
# --- Helper Functions ---
def analyze_image_with_gemini(image_bytes: bytes) -> str:
    """
    Ask the Gemini Vision model to describe a medical image.

    The bytes are decoded with PIL and sent to the module-level
    ``gemini_model`` together with a fixed instruction prompt. Every failure
    is reported in the Streamlit UI and turned into a status string, so this
    function does not raise.

    Args:
        image_bytes: The raw image file contents.

    Returns:
        The model's text on success. Otherwise one of the status messages
        built below: "Analysis blocked ..." when the request was blocked,
        "Error: ..." when the response had no content, or
        "Error analyzing image: ..." for any other exception.
    """
    # Instruction sent alongside the image.
    prompt = (
        "Analyze this medical image (e.g., pathology slide, diagram, scan).\n"
        "Describe the key visual features relevant to a medical context.\n"
        "Identify potential:\n"
        "- Diseases or conditions indicated\n"
        "- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)\n"
        "- Visible cell types\n"
        "- Relevant biomarkers (if inferable from staining or morphology)\n"
        "- Anatomical context (if discernible)\n"
        "Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.\n"
    )
    try:
        picture = Image.open(io.BytesIO(image_bytes))
        response = gemini_model.generate_content([prompt, picture])

        # Happy path: the response carries content parts.
        if response.parts:
            return response.text

        # No parts: either the prompt was blocked or the response is empty.
        feedback = response.prompt_feedback
        if feedback and feedback.block_reason:
            reason = feedback.block_reason
            st.warning(f"⚠️ Analysis blocked by safety settings: {reason}")
            return f"Analysis blocked due to safety settings: {reason}"

        st.error("❌ Gemini analysis returned no content. Response might be empty or invalid.")
        return "Error: Gemini analysis failed or returned no content."
    except genai.types.BlockedPromptException as blocked:
        st.error(f"❌ Gemini request blocked due to prompt content: {blocked}")
        return f"Analysis blocked (prompt issue): {blocked}"
    except Exception as failure:
        st.error(f"❌ An error occurred during Gemini analysis: {failure}")
        return f"Error analyzing image: {failure}"
def query_chroma(query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
    """
    Run a similarity search against the module-level Chroma ``collection``.

    Args:
        query_text: Text to search with; passed as a single-element query list.
        n_results: Maximum number of matches to request.

    Returns:
        The result mapping from ``collection.query`` (requested with
        'documents', 'metadatas' and 'distances'), or None when the query
        raised. In that case the error is shown via ``st.error``.
    """
    wanted_fields = ['documents', 'metadatas', 'distances']
    try:
        return collection.query(
            query_texts=[query_text],
            n_results=n_results,
            include=wanted_fields,
        )
    except Exception as query_error:
        st.error(f"❌ Error querying Chroma DB: {query_error}")
        return None
def add_dummy_data_to_chroma():
    """
    Add the predefined example medical snippets to the Chroma collection.

    The operation is idempotent: each document gets an ID derived from a hash
    of its text, and only documents whose ID is not already stored are added.
    Nested metadata values (the "entities" mapping) are serialized to JSON
    strings because Chroma metadata values must be scalars.

    Progress, success, "already present" and error states are reported through
    Streamlit status elements; the function returns nothing and does not raise.
    """
    # Function-scope imports keep this block self-contained.
    import hashlib
    import json

    st.info("Attempting to add dummy data to Chroma DB...")
    # --- IMPORTANT ---
    # In a real application, this data ingestion process would involve:
    # 1. Parsing actual medical documents (research papers, clinical notes, textbooks).
    # 2. Extracting relevant text chunks (e.g., using tools like Unstructured).
    # 3. Extracting or associating meaningful METADATA (source, patient ID (anonymized),
    #    image IDs linked to text, extracted entities like diseases/genes).
    # 4. Generating embeddings using the SAME embedding function used for querying.
    docs = [
        "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
        "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
        "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
        "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
        "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted.",
    ]
    raw_metadatas = [
        {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": {"DISEASES": ["adenocarcinoma", "lung cancer"], "PATHOLOGY_FINDINGS": ["glandular structures", "nuclear atypia", "papillary subtype"], "BIOMARKERS": ["TTF-1"]}, "IMAGE_ID": "fig_1a_adeno_lung.png"},
        {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": {"DISEASES": ["high-grade glioma", "glioblastoma"], "PATHOLOGY_FINDINGS": ["necrosis", "microvascular proliferation"], "BIOMARKERS": ["Ki-67"]}, "IMAGE_ID": "slide_34b_gbm.tiff"},
        {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": {"GENES": ["EGFR"], "DRUGS": ["tyrosine kinase inhibitors"], "DISEASES": ["non-small cell lung cancer"]}, "IMAGE_ID": "diagram_egfr_pathway.svg"},
        {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": {"DISEASES": ["chronic gastritis", "Helicobacter pylori infection"], "PATHOLOGY_FINDINGS": ["intestinal metaplasia"]}, "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
        {"source": "Case Study CJD", "topic": "Neuropathology", "entities": {"DISEASES": ["prion disease"], "PATHOLOGY_FINDINGS": ["Spongiform changes", "Gliosis"], "ANATOMICAL_LOCATIONS": ["cerebral cortex"]}, "IMAGE_ID": "slide_cjd_sample_02.jpg"},
    ]
    # Chroma rejects nested containers as metadata values, so dict/list values
    # are stored as JSON text; scalar values pass through untouched.
    metadatas = [
        {
            key: json.dumps(value) if isinstance(value, (dict, list)) else value
            for key, value in meta.items()
        }
        for meta in raw_metadatas
    ]
    # Content-derived IDs: the same text always maps to the same ID, which is
    # what makes the duplicate check below reliable across runs.
    ids = [
        f"doc_{hashlib.sha256(doc.encode('utf-8')).hexdigest()[:16]}"
        for doc in docs
    ]
    try:
        # Look the IDs up directly. A `where` filter would only match metadata
        # fields, never the document text itself.
        already_stored = set(collection.get(ids=ids).get("ids") or [])
        missing = [
            (doc_id, doc, meta)
            for doc_id, doc, meta in zip(ids, docs, metadatas)
            if doc_id not in already_stored
        ]
        if not missing:
            st.warning("⚠️ Dummy data (based on document text) seems to already exist in the collection. No new data added.")
            return
        collection.add(
            ids=[doc_id for doc_id, _, _ in missing],
            documents=[doc for _, doc, _ in missing],
            metadatas=[meta for _, _, meta in missing],
        )
        st.success(f"βœ… Added {len(missing)} dummy documents to Chroma collection '{COLLECTION_NAME}'.")
    except Exception as e:
        st.error(f"❌ Error adding dummy data to Chroma: {e}")
# --- Streamlit UI ---
# Page layout: a sidebar with the uploader and a demo-data button, then two
# columns — the uploaded image on the left, the Gemini analysis followed by
# the Chroma retrieval results on the right.
# NOTE(review): Streamlit's documentation has required st.set_page_config to be
# the first Streamlit command in a script, but the module-level setup earlier in
# this file already calls st.success/st.error. Confirm against the deployed
# Streamlit version; if it enforces the rule, move this call to the top of the
# file, right after the imports.
st.set_page_config(layout="wide", page_title="Medical Image Analysis & RAG")
st.title("βš•οΈ Medical Image Analysis & RAG")
st.markdown("""
Upload a medical image (e.g., pathology slide, diagram).
Google Gemini Vision will analyze it, and Chroma DB will retrieve related text snippets
from a simulated knowledge base based on the analysis.
""")

# Sidebar for Controls
with st.sidebar:
    st.header("βš™οΈ Controls")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "tiff", "webp"],
        help="Upload a medical image file."
    )
    st.divider()  # Visual separator
    if st.button("βž• Load Dummy KB Data", help="Add example text data to the Chroma vector database."):
        add_dummy_data_to_chroma()
    st.divider()
    st.info(
        "\nℹ️ **Note:**\n"
        f"- Chroma data is stored in the '{CHROMA_PATH}' folder within the app's environment.\n"
        "- This data persists across runs but **will be lost** if the hosting environment (e.g., Streamlit Cloud, Hugging Face Space) is reset or its storage is cleared.\n"
        "- Ensure the Google API Key is set in Streamlit Secrets.\n"
    )

# Main Display Area
col1, col2 = st.columns(2)  # Two side-by-side columns

with col1:
    st.subheader("πŸ–ΌοΈ Uploaded Image")
    if uploaded_file is not None:
        # Raw bytes are reused by the analysis column below.
        image_bytes = uploaded_file.getvalue()
        st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
    else:
        st.info("Upload an image using the sidebar to begin.")

with col2:
    st.subheader("πŸ”¬ Gemini Vision Analysis")
    if uploaded_file is not None:
        with st.spinner("🧠 Analyzing image with Gemini Vision... This may take a moment."):
            analysis_text = analyze_image_with_gemini(image_bytes)

        # analyze_image_with_gemini reports failures as status strings. Match
        # every prefix it produces — including "Error analyzing image:" — so a
        # failure message is never rendered as analysis or used as a RAG query.
        failure_prefixes = ("Error:", "Error analyzing image:", "Analysis blocked")
        if analysis_text.startswith(failure_prefixes):
            # The helper already surfaced the details via st.error/st.warning.
            st.markdown(f"**Analysis Status:** {analysis_text}")
        else:
            st.markdown(analysis_text)
            st.markdown("---")  # Separator before RAG results
            st.subheader("πŸ“š Related Information (RAG via Chroma DB)")

            # Use the analysis text itself as the similarity query.
            with st.spinner("πŸ” Searching knowledge base..."):
                chroma_results = query_chroma(analysis_text, n_results=3)  # Top 3 results

            if chroma_results and chroma_results.get('documents') and chroma_results['documents'][0]:
                # Results are nested one level per query; only one query is sent.
                hits = list(zip(
                    chroma_results['documents'][0],
                    chroma_results['metadatas'][0],
                    chroma_results['distances'][0],
                ))
                st.success(f"βœ… Found {len(hits)} related entries in the knowledge base:")
                for rank, (doc, meta, dist) in enumerate(hits, start=1):
                    meta = meta or {}  # Tolerate entries stored without metadata
                    # The collection is created with cosine distance, so
                    # 1 - distance is shown as the similarity score.
                    expander_title = f"Result {rank} (Similarity Score: {1-dist:.4f}) - Source: {meta.get('source', 'N/A')}"
                    with st.expander(expander_title):
                        st.markdown("**Retrieved Text:**")
                        st.markdown(f"> {doc}")  # Blockquote for the snippet
                        st.markdown("**Metadata:**")
                        st.json(meta)
                        # Point out when the snippet references another image/asset.
                        if meta.get("IMAGE_ID"):
                            st.info(f"ℹ️ This text chunk is associated with visual asset: `{meta['IMAGE_ID']}`")
                            # A more complete app could fetch and display that asset here.
            elif chroma_results is not None:  # Query ran successfully but found nothing
                st.warning("⚠️ No relevant information found in the knowledge base matching the image analysis.")
            # chroma_results is None means the query failed; query_chroma already showed the error.
    else:
        st.info("Analysis will appear here once an image is uploaded.")

st.markdown("---")
st.markdown("<div style='text-align: center;'>Powered by Google Gemini, Chroma DB, and Streamlit</div>", unsafe_allow_html=True)