Spaces:
Sleeping
Sleeping
# app.py | |
""" | |
MedSketch AI: Advanced Clinical Diagram Generator | |
A Streamlit application leveraging AI models (GPT-4o, potentially Stable Diffusion) | |
to generate medical diagrams based on user prompts, with options for styling, | |
metadata association, and annotations. | |
""" | |
import base64
import json
import logging
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import openai
import streamlit as st
from openai import OpenAI, OpenAIError  # Use modern OpenAI client and error types
from PIL import Image, ImageDraw
from streamlit_drawable_canvas import st_canvas
# ─── Constants ────────────────────────────────────────────────────────────────
APP_TITLE = "MedSketch AI — Advanced Clinical Diagram Generator"
# NOTE: this is the label shown in the UI; generate_openai_image requests the
# "dall-e-3" model from the OpenAI Images API under this label.
DEFAULT_MODEL = "GPT-4o (Vision)"
STABLE_DIFFUSION_MODEL = "Stable Diffusion LoRA"  # Placeholder backend (not implemented yet)
MODEL_OPTIONS = [DEFAULT_MODEL, STABLE_DIFFUSION_MODEL]
STYLE_PRESETS = ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
DEFAULT_STYLE = "Anatomical Diagram"
DEFAULT_STRENGTH = 0.7  # Initial value of the sidebar "Stylization Strength" slider
IMAGE_SIZE = "1024x1024"  # Size string passed to the OpenAI Images API
CANVAS_SIZE = 512  # Max side length (px) of the annotation canvas / placeholder image
ANNOTATION_COLOR = "rgba(255, 0, 0, 0.3)"  # Red with transparency
ANNOTATION_STROKE_WIDTH = 2
# Session-state keys shared by the generation, history and sidebar sections.
SESSION_STATE_ANNOTATIONS = "medsketch_annotations"
SESSION_STATE_HISTORY = "medsketch_history"
# ─── Setup & Configuration ────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

st.set_page_config(
    page_title=APP_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "About": f"{APP_TITLE} - Generates medical diagrams using AI.",
        "Get Help": None,  # Add a link if you have one
        "Report a bug": None,  # Add a link if you have one
    },
)


def _read_openai_api_key() -> Optional[str]:
    """Return the OpenAI API key from Streamlit secrets, falling back to the environment.

    Accessing ``st.secrets`` can raise (instead of returning a default) when no
    secrets file is configured, which is the usual local-development setup, so
    that lookup is guarded and ``OPENAI_API_KEY`` from the environment is used.
    """
    try:
        key = st.secrets.get("OPENAI_API_KEY")
    except Exception:  # Missing/unparsable secrets file: fall back to the env var.
        logger.debug("Streamlit secrets unavailable; using environment.", exc_info=True)
        key = None
    return key or os.getenv("OPENAI_API_KEY")


@st.cache_resource(show_spinner=False)
def _get_openai_client(key: str) -> OpenAI:
    """Build the OpenAI client once per key and reuse it across Streamlit reruns."""
    return OpenAI(api_key=key)


# Initialize the OpenAI client; halt the app when no key is configured.
api_key = _read_openai_api_key()
if not api_key:
    st.error("🚨 OpenAI API Key not found! Please set it in Streamlit secrets or environment variables.", icon="🚨")
    st.stop()  # Halt execution if no key
try:
    client = _get_openai_client(api_key)
    logger.info("OpenAI client initialized successfully.")
except Exception as e:
    st.error(f"🚨 Failed to initialize OpenAI client: {e}", icon="🚨")
    logger.exception("OpenAI client initialization failed.")
    st.stop()
# ─── Helper Functions ─────────────────────────────────────────────────────────
def generate_openai_image(prompt: str, style: str, strength: float) -> Image.Image:
    """Generate a medical illustration with the OpenAI Images API ("dall-e-3").

    The preset style and the strength value are not passed as API parameters;
    both are folded into the text prompt as hints.

    Args:
        prompt: The user's text prompt.
        style: The selected style preset or custom style description.
        strength: The stylization strength, embedded in the prompt text.

    Returns:
        A fully decoded PIL Image.

    Raises:
        OpenAIError: If the API call fails.
        IOError: If the response carries no image data or it cannot be decoded.
    """
    logger.info("Requesting OpenAI image generation for prompt %r with style %r", prompt, style)
    full_prompt = (
        f"Style: [{style}], Strength: [{strength:.2f}] - "
        f"Generate the following medical illustration: {prompt}"
    )
    try:
        response = client.images.generate(
            model="dall-e-3",
            prompt=full_prompt,
            size=IMAGE_SIZE,
            quality="standard",  # or "hd"
            n=1,
            # Inline base64 payload: no second download request and no
            # dependency on the (undeclared) `requests` package.
            response_format="b64_json",
        )
    except OpenAIError:
        logger.exception("OpenAI API error during image generation.")
        raise

    payload = response.data[0].b64_json if response.data else None
    if not payload:
        raise IOError("OpenAI response did not include image data.")
    try:
        image = Image.open(BytesIO(base64.b64decode(payload)))
        image.load()  # Decode now so corrupt data fails here, not at display time.
    except (ValueError, OSError) as e:  # binascii.Error is a ValueError subclass.
        logger.exception("Failed to decode image returned by OpenAI.")
        raise IOError(f"Failed to decode generated image: {e}") from e
    logger.info("Image generated successfully (%dx%d).", image.width, image.height)
    return image
def generate_sd_image(prompt: str, style: str, strength: float) -> Image.Image:
    """Return a placeholder image in place of a Stable Diffusion LoRA generation.

    No Stable Diffusion backend is wired up yet. This warns the user and draws
    the style and (possibly truncated) prompt onto a plain grey image so the
    rest of the UI can be exercised. Replace the body with a real model call.

    Args:
        prompt: The user's text prompt.
        style: The selected style preset.
        strength: The stylization strength (unused by the placeholder).

    Returns:
        A CANVAS_SIZE x CANVAS_SIZE RGB placeholder image.
    """
    logger.warning("Stable Diffusion LoRA model generation is not implemented. Returning placeholder.")
    st.warning("🚧 Stable Diffusion LoRA generation is not yet implemented. Using placeholder.", icon="🚧")

    preview_chars = 50
    # Only mark the prompt as truncated when it actually was.
    shown_prompt = prompt if len(prompt) <= preview_chars else f"{prompt[:preview_chars]}..."
    img = Image.new("RGB", (CANVAS_SIZE, CANVAS_SIZE), color=(210, 210, 210))
    ImageDraw.Draw(img).text(
        (10, 10),
        f"Stable Diffusion Placeholder\nStyle: {style}\nPrompt: {shown_prompt}",
        fill=(0, 0, 0),
    )
    return img
def display_result(image: Image.Image, prompt: str, index: int, total: int) -> Optional[List[Dict[str, Any]]]:
    """Show one generated image with a PNG download button and an annotation canvas.

    Args:
        image: The PIL Image to display.
        prompt: The prompt used to generate the image (caption and file name).
        index: Zero-based position of the image; also makes the widget keys unique.
        total: The total count shown in the caption ("Result i/total").

    Returns:
        The objects drawn on the canvas (list of dicts) if the user has drawn
        anything, otherwise None. Presumably their coordinates are relative to
        the downscaled canvas image (at most CANVAS_SIZE px per side), not to
        the full-resolution image.
    """
    st.image(image, caption=f"Result {index + 1}/{total}: {prompt}", use_container_width=True)

    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    # Keep only filename-safe characters: prompts routinely contain ':' or '/'
    # (e.g. "Example 2: ..."), which are invalid or act as path separators.
    slug = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt[:20]).strip("_") or "image"
    st.download_button(
        label="⬇️ Download PNG",
        data=png_buffer.getvalue(),
        file_name=f"medsketch_{index + 1}_{slug}.png",
        mime="image/png",
        key=f"download_{index}",
    )

    st.markdown("**✏️ Annotate:**")
    # Downscale a copy for the canvas; thumbnail() keeps the aspect ratio and
    # never enlarges, and the canvas is sized to match it exactly.
    canvas_image = image.copy()
    canvas_image.thumbnail((CANVAS_SIZE, CANVAS_SIZE))
    canvas_result = st_canvas(
        # NOTE(review): in "freedraw" mode strokes are colored by stroke_color
        # (library default); fill_color applies to filled shapes. Confirm whether
        # ANNOTATION_COLOR was meant to be passed as stroke_color.
        fill_color=ANNOTATION_COLOR,
        stroke_width=ANNOTATION_STROKE_WIDTH,
        background_image=canvas_image,
        update_streamlit=True,  # Send drawing updates back to the script as they happen.
        height=canvas_image.height,
        width=canvas_image.width,
        drawing_mode="freedraw",  # Or choose other modes like "line", "rect", etc.
        key=f"canvas_{index}",
    )
    objects = (canvas_result.json_data or {}).get("objects")
    return objects or None
# ─── Initialize Session State ─────────────────────────────────────────────────
# Seed the per-session containers on the first run only; later reruns keep
# whatever the session already holds.
# Annotations map: prompt text -> list of canvas annotation objects.
st.session_state.setdefault(SESSION_STATE_ANNOTATIONS, {})
# History list: one dict per generation attempt, in the order they happened.
st.session_state.setdefault(SESSION_STATE_HISTORY, [])
# ─── Sidebar: Settings & Metadata ─────────────────────────────────────────────
with st.sidebar:
    st.header("⚙️ Generation Settings")
    model_choice = st.selectbox(
        "Select Model",
        options=MODEL_OPTIONS,
        index=MODEL_OPTIONS.index(DEFAULT_MODEL),
        help="Choose the AI model for image generation.",
    )
    style_preset = st.radio(
        "Select Preset Style",
        options=STYLE_PRESETS,
        index=STYLE_PRESETS.index(DEFAULT_STYLE),
        horizontal=True,  # More compact layout
        help="Apply a predefined visual style to the generation.",
    )
    # Offer free-text style input only when "Custom" is selected.
    custom_style_input = ""
    if style_preset == "Custom":
        custom_style_input = st.text_input("Enter Custom Style Description:", key="custom_style").strip()
        if not custom_style_input:
            st.caption(f"No custom style entered yet; using '{DEFAULT_STYLE}'.")
    # A blank custom description would otherwise yield an empty style string.
    final_style = (custom_style_input or DEFAULT_STYLE) if style_preset == "Custom" else style_preset
    strength = st.slider(
        "Stylization Strength",
        min_value=0.1,
        max_value=1.0,
        value=DEFAULT_STRENGTH,
        step=0.05,
        help="Controls how strongly the chosen style influences the result (conceptual).",
    )

    st.markdown("---")
    st.header("📋 Optional Metadata")
    patient_id = st.text_input("Patient / Case ID", key="patient_id", help="Associate with a specific patient or case.")
    roi = st.text_input("Region of Interest (ROI)", key="roi", help="Specify the anatomical region shown.")
    umls_code = st.text_input("UMLS / SNOMED CT Code", key="umls_code", help="Link to relevant medical ontology codes.")

    st.markdown("---")
    if st.button("⚠️ Clear History & Annotations", help="Removes all generated images and annotations from this session."):
        st.session_state[SESSION_STATE_ANNOTATIONS] = {}
        st.session_state[SESSION_STATE_HISTORY] = []
        st.rerun()  # Refresh the page to reflect cleared state
# ─── Main Application Area ────────────────────────────────────────────────────
st.title(APP_TITLE)
st.markdown("Generate medical illustrations from text descriptions using AI. Annotate and export your results.")

# --- Prompt Input Area ---
prompt_input_area = st.container()
with prompt_input_area:
    st.subheader("📝 Enter Prompt(s)")
    st.caption("Enter one prompt per line to generate multiple images in a batch.")
    raw_prompts = st.text_area(
        "Describe the medical diagram(s) you need:",
        placeholder=(
            "Example 1: A sagittal view of the human knee joint, labeling the ACL, PCL, meniscus, femur, and tibia.\n"
            "Example 2: High-power field H&E stain of lung adenocarcinoma showing glandular formation.\n"
            "Example 3: Immunohistochemistry (IHC) stain for PD-L1 in tonsil tissue, showing positive staining on immune cells."
        ),
        height=150,  # Slightly larger height
        label_visibility="collapsed",
    )
# One prompt per non-blank line, surrounding whitespace removed.
prompts: List[str] = [p.strip() for p in raw_prompts.splitlines() if p.strip()]

# --- Generation Trigger ---
generate_button = st.button(
    f"🚀 Generate Diagram{'s' if len(prompts) > 1 else ''}",
    type="primary",
    disabled=not prompts,  # Nothing to generate without at least one prompt
    use_container_width=True,
)
# --- Generation and Display Area ---
# Streamlit reruns this whole script on every interaction (each canvas stroke,
# each download click). Generated images are therefore stored in the session
# history and re-rendered on every run. Rendering them only inside
# `if generate_button:` made them vanish on the first rerun, so drawn
# annotations could never be read back.
results_area = st.container()
MAX_RESULT_COLUMNS = 3  # Grid width used to lay out a batch of results.


def _generate_for_model(model: str, prompt: str, style: str, strength_value: float) -> Image.Image:
    """Dispatch *prompt* to the generator behind the selected *model* label.

    Raises:
        ValueError: If *model* is not one of the known model labels.
    """
    if model == DEFAULT_MODEL:
        return generate_openai_image(prompt, style, strength_value)
    if model == STABLE_DIFFUSION_MODEL:
        return generate_sd_image(prompt, style, strength_value)
    raise ValueError(f"Unknown model selected: {model}")


def _image_to_png_bytes(image: Image.Image) -> bytes:
    """Encode *image* as PNG bytes so it can be kept in session state."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _run_generation_batch(batch_prompts: List[str]) -> None:
    """Generate one image per prompt and append every outcome to session history.

    Each entry is tagged with a batch number so the latest batch can be
    re-rendered on later reruns. Successful entries also carry the PNG bytes of
    the image, which costs session memory (one encoded PNG per result).
    """
    history: List[Dict[str, Any]] = st.session_state[SESSION_STATE_HISTORY]
    batch_id = max((entry.get("batch", 0) for entry in history), default=0) + 1
    total = len(batch_prompts)
    logger.info("Starting generation for %d prompts using model '%s'.", total, model_choice)
    progress_bar = st.progress(0.0, text="Initializing generation...")
    succeeded = 0
    for i, prompt in enumerate(batch_prompts):
        with st.spinner(f"Generating image {i + 1}/{total} for prompt: \"{prompt[:50]}...\""):
            try:
                image = _generate_for_model(model_choice, prompt, final_style, strength)
                history.append({
                    "prompt": prompt,
                    "model": model_choice,
                    "style": final_style,
                    "strength": strength,
                    "metadata": {
                        "patient_id": patient_id,
                        "roi": roi,
                        "umls_code": umls_code,
                    },
                    "image_ref_index": i,  # Position within this batch
                    "batch": batch_id,
                    "image_png": _image_to_png_bytes(image),
                })
                succeeded += 1
            except Exception as e:  # Per-prompt boundary: one failure must not abort the batch.
                logger.exception("Generation failed for prompt %r", prompt)
                history.append({"prompt": prompt, "status": "failed", "error": str(e), "batch": batch_id})
        progress_bar.progress((i + 1) / total, text=f"Generated {i + 1}/{total} images...")
    progress_bar.progress(1.0, text="Batch generation complete!")
    st.toast(f"Finished generating {succeeded}/{total} image(s)!", icon="🎉")


def _render_latest_batch() -> None:
    """Display the most recent batch and store any annotations drawn on it."""
    history: List[Dict[str, Any]] = st.session_state[SESSION_STATE_HISTORY]
    tagged = [(idx, entry) for idx, entry in enumerate(history) if "batch" in entry]
    if not tagged:
        return
    latest = max(entry["batch"] for _, entry in tagged)
    batch = [(idx, entry) for idx, entry in tagged if entry["batch"] == latest]
    cols = st.columns(min(MAX_RESULT_COLUMNS, len(batch)))
    for position, (history_index, entry) in enumerate(batch):
        with cols[position % MAX_RESULT_COLUMNS]:
            prompt = entry["prompt"]
            if entry.get("status") == "failed":
                st.error(f"Failed to generate image for prompt: '{prompt}'. Error: {entry['error']}", icon="💥")
                continue
            if not entry.get("image_png"):
                continue
            image = Image.open(BytesIO(entry["image_png"]))
            # Pass the history position (not the in-batch position) as the index
            # so widget keys stay unique across batches; reusing e.g. "canvas_0"
            # for a new image could carry the previous drawing over.
            annotations = display_result(image, prompt, history_index, len(history))
            if annotations:
                st.session_state[SESSION_STATE_ANNOTATIONS][prompt] = annotations
                st.success(f"Annotations saved for prompt {position + 1}.", icon="✅")


with results_area:
    if generate_button:
        if not prompts:
            st.warning("⚠️ Please enter at least one prompt description.", icon="⚠️")
        else:
            _run_generation_batch(prompts)
    _render_latest_batch()
# ─── History & Exports Section ────────────────────────────────────────────────
history_area = st.container()


def _render_history_item(position: int, item: Dict[str, Any]) -> None:
    """Render one session-history entry: a warning if it failed, else details and annotations.

    Args:
        position: Zero-based position of the entry in the session history.
        item: A history entry. Failed entries carry "prompt", "status" and "error";
            successful ones carry "prompt", "model", "style" and optional "metadata".
    """
    if item.get("status") == "failed":
        st.warning(f"**Prompt {position + 1} (Failed):** {item['prompt']}  \n *Error: {item['error']}*", icon="⚠️")
        return
    prompt_key = item["prompt"]
    st.markdown(f"**Prompt {position + 1}:** `{prompt_key}`")
    st.markdown(f"*Model: {item['model']}, Style: {item['style']}*")
    # List only the metadata fields that were filled in; tolerate a missing/None dict.
    meta_str = ", ".join(f"{k}: {v}" for k, v in (item.get("metadata") or {}).items() if v)
    if meta_str:
        st.markdown(f"*Metadata: {meta_str}*")
    annotations = st.session_state[SESSION_STATE_ANNOTATIONS].get(prompt_key)
    if annotations:
        with st.expander(f"View {len(annotations)} Annotation(s)"):
            st.json(annotations)
    else:
        st.caption("_(No annotations made for this item yet)_")


def _build_annotation_export(
    annotations_by_prompt: Dict[str, Any], history: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Pair each prompt's annotations with the generation details recorded for it.

    When a prompt occurs several times in history, the last successful entry wins.
    Prompts without a successful entry get ``None`` details and metadata.
    """
    successful = {item["prompt"]: item for item in history if item.get("status") != "failed"}
    export: Dict[str, Any] = {}
    for prompt, ann_objs in annotations_by_prompt.items():
        source = successful.get(prompt)
        export[prompt] = {
            "annotations": ann_objs,
            "generation_details": (
                {
                    "model": source.get("model"),
                    "style": source.get("style"),
                    "strength": source.get("strength"),
                }
                if source
                else None
            ),
            "metadata": source.get("metadata") if source else None,
        }
    return export


with history_area:
    session_history = st.session_state[SESSION_STATE_HISTORY]
    if session_history:
        st.markdown("---")
        st.subheader("📜 Session History & Annotations")
        st.caption("Review generated images (if stored) and their annotations from this session.")
        for position, item in enumerate(session_history):
            _render_history_item(position, item)
            st.markdown("---")  # Separator between history items

        # --- Export Annotations ---
        if st.session_state[SESSION_STATE_ANNOTATIONS]:
            st.markdown("---")
            st.subheader("⬇️ Export Annotations")
            try:
                json_data = json.dumps(
                    _build_annotation_export(st.session_state[SESSION_STATE_ANNOTATIONS], session_history),
                    indent=2,
                )
            except (TypeError, ValueError) as e:  # Non-serializable or circular annotation data.
                st.error(f"Failed to prepare annotations for download: {e}")
                logger.exception("Error preparing JSON export")
            else:
                st.download_button(
                    label="⬇️ Export All Annotations (JSON)",
                    data=json_data,
                    file_name="medsketch_session_annotations.json",
                    mime="application/json",
                    help="Download all annotations made during this session, including associated metadata.",
                )
    elif generate_button:  # Generate was clicked but nothing reached the history.
        st.info("No successful generations or annotations in the current session yet.")
# ─── Footer ───────────────────────────────────────────────────────────────────
# A horizontal rule, then a small attribution line at the bottom of the page.
st.markdown("---")
st.caption("MedSketch AI - Powered by Streamlit and OpenAI")