Spaces:
Sleeping
Sleeping
# --- Docstring ---
"""
Streamlit application for Medical Image Analysis using Google Gemini Vision
and Retrieval-Augmented Generation (RAG) with Chroma DB, enhanced for
Hugging Face Spaces deployment and improved practices.

Features:
- Image analysis via Google Gemini Pro Vision.
- RAG using Chroma DB with Hugging Face embeddings.
- Caching for performance.
- Basic logging.
- Improved UX and error handling.
- Explicit Disclaimer.
"""
# --- Imports --- | |
import streamlit as st | |
import google.generativeai as genai | |
import chromadb | |
from chromadb.utils import embedding_functions | |
from PIL import Image | |
import io | |
import time | |
import logging | |
from typing import Optional, Dict, List, Any, Tuple | |
# --- Basic Logging Setup ---
# Configure the root logger at import time: INFO threshold, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the functions and script code in this file.
logger = logging.getLogger(__name__)
# --- Application Configuration ---
# Secrets Management (Prioritize Hugging Face Secrets)
# GOOGLE_API_KEY is mandatory. HF_TOKEN is optional for many public models but
# required for gated/private ones, hence the tolerant .get() lookup.
try:
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
    HF_TOKEN = st.secrets.get("HF_TOKEN")
except Exception as e:
    # A KeyError means the required secret is absent; anything else means the
    # secrets store itself could not be read. Both cases are reported in the
    # UI and the log, then the script run is halted.
    if isinstance(e, KeyError):
        err_msg = f"β Missing Secret: {e}. Please add it to your Hugging Face Space secrets."
    else:
        err_msg = f"β Error loading secrets: {e}"
    st.error(err_msg)
    logger.error(err_msg)
    st.stop()
# Gemini Configuration
# NOTE(review): confirm this model id is still served by the Gemini API before
# deploying; older vision model ids have been retired over time.
VISION_MODEL_NAME = "gemini-pro-vision"
# Sampling parameters passed to the model as its generation config.
GENERATION_CONFIG = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 1024,
}
# One entry per harm category, all using the same blocking threshold.
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
# Instruction text sent to the model together with the uploaded image.
# Kept free of any trailing table/markup residue so the model receives only
# the intended prompt.
GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
Describe the key visual features relevant to a medical context.
Identify potential:
- Diseases or conditions indicated
- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
- Visible cell types
- Relevant biomarkers (if inferable from staining or morphology)
- Anatomical context (if discernible)
Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
Structure the output clearly, perhaps using bullet points for findings.
"""

# Chroma DB Configuration
CHROMA_PATH = "chroma_data_hf"  # Use a distinct path if needed
COLLECTION_NAME = "medical_docs_hf"
# IMPORTANT: Choose an appropriate HF embedding model. 'all-mpnet-base-v2' is general purpose.
# For better medical results, consider models like:
# - 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext' (might need more RAM/compute)
# - 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'
# - Other models tagged 'medical' or 'biomedical' on Hugging Face Hub.
# Ensure the chosen model is compatible with chromadb's HuggingFaceEmbeddingFunction.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"  # <-- REPLACE if possible
# Stored as the collection's "hnsw:space" metadata when the collection is created.
CHROMA_DISTANCE_METRIC = "cosine"
# --- Caching Resource Initialization --- | |
def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
    """Configure the Gemini SDK and construct the vision model.

    Uses the module-level ``GOOGLE_API_KEY``, ``VISION_MODEL_NAME``,
    ``GENERATION_CONFIG`` and ``SAFETY_SETTINGS``.

    Returns:
        The constructed ``genai.GenerativeModel``, or ``None`` if configuring
        the SDK or building the model raised. In the failure case the error is
        shown in the UI and logged with its traceback.
    """
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        gemini = genai.GenerativeModel(
            model_name=VISION_MODEL_NAME,
            generation_config=GENERATION_CONFIG,
            safety_settings=SAFETY_SETTINGS,
        )
    except Exception as exc:
        failure = f"β Error initializing Gemini Model ({VISION_MODEL_NAME}): {exc}"
        st.error(failure)
        logger.error(failure, exc_info=True)
        return None
    else:
        logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
        return gemini
def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
    """Build the Hugging Face embedding function used by the Chroma collection.

    Uses the module-level ``HF_TOKEN`` and ``EMBEDDING_MODEL_NAME``.

    Returns:
        The embedding function, or ``None`` if construction raised. In the
        failure case the error is shown in the UI (with a troubleshooting
        hint) and logged with its traceback.
    """
    try:
        # The token is always passed through, even when it is None.
        # NOTE(review): confirm the installed chromadb version accepts
        # api_key=None for public models.
        embed_func = embedding_functions.HuggingFaceEmbeddingFunction(
            api_key=HF_TOKEN,
            model_name=EMBEDDING_MODEL_NAME,
        )
        logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
        return embed_func
    except Exception as e:
        err_msg = f"β Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        st.info("βΉοΈ Make sure the embedding model name is correct and you have network access. "
                "If using a private model, ensure HF_TOKEN is set in secrets.")
        return None
def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
    """Open (or create) the persistent Chroma collection.

    Args:
        _embedding_func: Embedding function to attach to the collection. A
            falsy value aborts initialization.

    Returns:
        The collection named ``COLLECTION_NAME`` stored under ``CHROMA_PATH``,
        or ``None`` when no embedding function was supplied or when the client
        or collection could not be created (the error is shown in the UI; the
        exception case is also logged with its traceback).
    """
    if not _embedding_func:
        st.error("β Cannot initialize Chroma DB without a valid embedding function.")
        return None

    try:
        client = chromadb.PersistentClient(path=CHROMA_PATH)
        # The distance metric is recorded as collection metadata at creation.
        space_metadata = {"hnsw:space": CHROMA_DISTANCE_METRIC}
        chroma_collection = client.get_or_create_collection(
            name=COLLECTION_NAME,
            embedding_function=_embedding_func,
            metadata=space_metadata,
        )
    except Exception as exc:
        failure = f"β Error initializing Chroma DB at '{CHROMA_PATH}': {exc}"
        st.error(failure)
        logger.error(failure, exc_info=True)
        st.info(f"βΉοΈ Ensure the path '{CHROMA_PATH}' is writable.")
        return None
    else:
        logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
        return chroma_collection
# --- Core Logic Functions (with Caching for Data Operations) --- | |
# Show spinner manually in UI | |
def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
    """Send the image plus the analysis prompt to Gemini.

    No caching decorator is visible on this function, so every call issues a
    new request to the model.

    Args:
        _gemini_model: Initialized Gemini model; a falsy value yields an
            error result without calling the API.
        image_bytes: Raw bytes of the uploaded image, decoded with Pillow.

    Returns:
        ``(text, is_error)``. On success ``text`` is the model's answer and
        ``is_error`` is ``False``. On a safety block, an empty response or any
        exception, ``text`` is a human-readable message and ``is_error`` is
        ``True``. This function never raises.
    """
    if not _gemini_model:
        return "Error: Gemini model not initialized.", True
    try:
        img = Image.open(io.BytesIO(image_bytes))
        response = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, img])

        # Inspect the prompt feedback BEFORE touching `response.parts`: the
        # SDK's `parts` accessor raises when the response has no candidates
        # (the blocked-prompt case), which would otherwise skip this branch
        # and surface a generic error instead of the block reason.
        feedback = getattr(response, "prompt_feedback", None)
        reason = getattr(feedback, "block_reason", None)
        if reason:
            msg = f"Analysis blocked by safety settings: {reason}"
            logger.warning(msg)
            return msg, True  # Indicate block/error state

        if not response.parts:
            msg = "Error: Gemini analysis returned no content (empty or invalid response)."
            logger.error(msg)
            return msg, True

        # Resolve the text before logging success; a failure while reading it
        # is reported by the generic handler below.
        text = response.text
        logger.info("Gemini analysis successful.")
        return text, False  # Indicate success
    except genai.types.BlockedPromptException as e:
        msg = f"Analysis blocked (prompt issue): {e}"
        logger.warning(msg)
        return msg, True
    except Exception as e:
        msg = f"Error during Gemini analysis: {e}"
        logger.error(msg, exc_info=True)
        return msg, True
def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
    """Run a text query against the Chroma collection.

    Args:
        _collection: Collection to search; a falsy value returns ``None``.
        query_text: Text to search with; an empty value is logged as a
            warning and returns ``None``.
        n_results: Number of results requested from the collection.

    Returns:
        The result dict from ``collection.query`` (documents, metadatas and
        distances are requested), or ``None`` if a precondition failed or the
        query raised. A query failure is shown in the UI and logged with its
        traceback.
    """
    if not _collection:
        return None
    if not query_text:
        logger.warning("Attempted to query Chroma with empty text.")
        return None

    # Hook for future query refinement; the raw analysis text is used as-is.
    search_text = query_text
    try:
        hits = _collection.query(
            query_texts=[search_text],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances'],
        )
    except Exception as exc:
        failure = f"Error querying Chroma DB: {exc}"
        st.error(failure)  # Surface the failure in the UI as well as the log
        logger.error(failure, exc_info=True)
        return None
    else:
        logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
        return hits
def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
    """Insert the example medical snippets into Chroma, skipping ones already present.

    Each document gets an ID derived from a hash of its text, so the ID is the
    same on every run. Existing entries are detected with an ID lookup and
    only missing documents are added. (The earlier approach filtered with
    ``where={"document": ...}``, which matches metadata fields rather than
    document text, and used time-based IDs, so it re-inserted the same
    snippets on every call.) Entries inserted earlier under time-based IDs
    are not detected by this lookup.

    Args:
        collection: Target Chroma collection.
        embedding_func: Only checked for truthiness here; embeddings are
            produced by the collection's configured function during ``add``.

    Returns:
        None. Progress and failures are reported through an ``st.status``
        widget and the module logger.
    """
    import hashlib  # local import: only this helper needs it

    if not collection or not embedding_func:
        st.error("β Cannot add dummy data: Chroma Collection or Embedding Function not available.")
        return
    status = st.status("Adding dummy data to Chroma DB...", expanded=False)
    try:
        # --- Dummy Data Definition ---
        docs = [
            "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
            "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
            "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
            "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
            "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted.",
        ]
        metadatas = [
            {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
            {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
            {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer", "IMAGE_ID": "diagram_egfr_pathway.svg"},
            {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
            {"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"},
        ]
        # Stable, content-derived IDs: identical text always maps to the same ID.
        ids = [
            f"doc_hf_{hashlib.sha256(doc.encode('utf-8')).hexdigest()[:16]}"
            for doc in docs
        ]

        status.update(label="Checking for existing dummy documents...")
        existing = collection.get(ids=ids, include=[])
        existing_ids = set(existing.get('ids') or []) if existing else set()
        missing = [
            (doc, meta, doc_id)
            for doc, meta, doc_id in zip(docs, metadatas, ids)
            if doc_id not in existing_ids
        ]

        if missing:
            new_docs = [doc for doc, _, _ in missing]
            new_metas = [meta for _, meta, _ in missing]
            new_ids = [doc_id for _, _, doc_id in missing]
            status.update(label=f"Generating embeddings for {len(new_docs)} documents (may take time)...")
            # Embeddings are generated implicitly by ChromaDB during .add()
            # when an embedding_function is configured for the collection.
            collection.add(
                documents=new_docs,
                metadatas=new_metas,
                ids=new_ids,
            )
            status.update(label=f"β Added {len(new_docs)} dummy documents.", state="complete")
            logger.info(f"Added {len(new_docs)} dummy documents to collection '{COLLECTION_NAME}'.")
        else:
            status.update(label="β οΈ Dummy data already exists. No new data added.", state="complete")
            logger.warning("Dummy data seems to already exist in the collection based on text match.")
    except Exception as e:
        err_msg = f"Error adding dummy data to Chroma: {e}"
        status.update(label=f"β Error: {err_msg}", state="error")
        logger.error(err_msg, exc_info=True)
# --- Page Setup ---
# Page configuration is issued before anything else in this section can emit
# a Streamlit element: the initializers below call st.error on failure, and
# Streamlit documents set_page_config as the first command of a page.
st.set_page_config(layout="wide", page_title="Medical Image Analysis & RAG (HF)")

# --- Initialize Resources ---
# NOTE(review): no caching decorator is visible on these initializers in this
# file, so as written they run on every script rerun.
gemini_model = initialize_gemini_model()
embedding_func = initialize_embedding_function()
collection = initialize_chroma_collection(embedding_func)  # Pass embedding func to chroma init

# --- Streamlit UI ---
st.title("βοΈ Medical Image Analysis & RAG (Hugging Face Enhanced)")

# --- DISCLAIMER ---
st.warning("""
**β οΈ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis, treatment, or decision-making.
AI analysis can be imperfect. Always consult with qualified healthcare professionals for any medical concerns.
Do **NOT** upload identifiable patient data (PHI).
""")
st.markdown("""
Upload a medical image. Gemini Vision will analyze it, and related information
will be retrieved from a Chroma DB knowledge base using Hugging Face embeddings.
""")

# Sidebar: upload control, knowledge-base seeding button and setup summary.
with st.sidebar:
    st.header("βοΈ Controls")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "tiff", "webp"],
        help="Upload a medical image file (e.g., pathology, diagram)."
    )
    st.divider()
    if st.button("β Add/Verify Dummy KB Data", help="Adds example text data to Chroma DB if it doesn't exist."):
        if collection and embedding_func:
            add_dummy_data_to_chroma(collection, embedding_func)
        else:
            st.error("β Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
    st.divider()
    st.info(f"""
    **Setup Info:**
    - Gemini Model: `{VISION_MODEL_NAME}`
    - Embedding Model: `{EMBEDDING_MODEL_NAME}`
    - Chroma Collection: `{COLLECTION_NAME}` (at `{CHROMA_PATH}`)
    - Distance Metric: `{CHROMA_DISTANCE_METRIC}`
    """)
    # NOTE(review): this shows the key's length and last four characters to
    # every visitor; consider whether that is acceptable on a public Space.
    st.caption(f"Using Google API Key: {'*' * (len(GOOGLE_API_KEY)-4)}{GOOGLE_API_KEY[-4:]}" if GOOGLE_API_KEY else "Not Set")
    st.caption(f"Using HF Token: {'Provided' if HF_TOKEN else 'Not Provided'}")

# Main Display Area
col1, col2 = st.columns(2)

with col1:
    st.subheader("πΌοΈ Uploaded Image")
    if uploaded_file is not None:
        image_bytes = uploaded_file.getvalue()
        st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
    else:
        st.info("Upload an image using the sidebar to begin.")

with col2:
    st.subheader("π¬ Analysis & Retrieval")
    if uploaded_file is not None and gemini_model and collection:
        # 1. Analyze Image
        analysis_text = ""
        analysis_error = False
        with st.status("π§ Analyzing image with Gemini Vision...", expanded=True) as status_gemini:
            analysis_text, analysis_error = analyze_image_with_gemini(gemini_model, image_bytes)
            if analysis_error:
                # Use the part after the first colon as a short label, if any.
                detail = analysis_text.split(':')[1].strip() if ':' in analysis_text else 'See details'
                status_gemini.update(label=f"β οΈ Analysis Failed/Blocked: {detail}", state="error")
                st.error(f"**Analysis Output:** {analysis_text}")  # Show error/block message
            else:
                status_gemini.update(label="β Analysis Complete", state="complete")
                st.markdown("**Gemini Vision Analysis:**")
                st.markdown(analysis_text)

        # 2. Query Chroma if Analysis Succeeded
        if not analysis_error and analysis_text:
            st.markdown("---")
            st.subheader("π Related Information (RAG)")
            with st.status("π Searching knowledge base (Chroma DB)...", expanded=True) as status_chroma:
                chroma_results = query_chroma(collection, analysis_text, n_results=3)
                has_hits = bool(
                    chroma_results
                    and chroma_results.get('documents')
                    and chroma_results['documents'][0]
                )
                if has_hits:
                    num_results = len(chroma_results['documents'][0])
                    status_chroma.update(label=f"β Found {num_results} related entries.", state="complete")
                elif chroma_results is not None:  # Query ran, no results
                    # st.status only accepts "running", "complete" or "error";
                    # an empty result set is a completed search, not a failure.
                    status_chroma.update(label="β οΈ No relevant information found.", state="complete")
                else:  # Error occurred during query (already logged and shown via st.error)
                    status_chroma.update(label="β Failed to retrieve results.", state="error")

            # Results are rendered outside the status container because each
            # one uses its own expander, and the status widget is itself an
            # expandable container.
            # NOTE(review): confirm whether the deployed Streamlit version
            # permits nesting expanders inside st.status.
            if has_hits:
                hits = zip(
                    chroma_results['documents'][0],
                    chroma_results['metadatas'][0],
                    chroma_results['distances'][0],
                )
                for rank, (doc, meta, dist) in enumerate(hits, start=1):
                    meta = meta or {}  # tolerate entries stored without metadata
                    similarity = 1.0 - dist  # For cosine distance
                    expander_title = f"Result {rank} (Similarity: {similarity:.4f}) | Source: {meta.get('source', 'N/A')}"
                    with st.expander(expander_title):
                        st.markdown("**Retrieved Text:**")
                        st.markdown(f"> {doc}")
                        st.markdown("**Metadata:**")
                        # One bullet per metadata field, with a prettified key.
                        for key, value in meta.items():
                            st.markdown(f"- **{key.replace('_', ' ').title()}:** `{value}`")
                        # Highlight linked image ID
                        if meta.get("IMAGE_ID"):
                            st.info(f"βΉοΈ Associated visual asset ID: `{meta['IMAGE_ID']}`")
    elif not uploaded_file:
        st.info("Analysis results will appear here once an image is uploaded.")
    else:
        st.error("β Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar/logs).")

st.markdown("---")
st.markdown("<div style='text-align: center; font-size: small;'>Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit</div>", unsafe_allow_html=True)