mgbam commited on
Commit
6a39465
Β·
verified Β·
1 Parent(s): 81a11e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -383
app.py CHANGED
@@ -1,388 +1,111 @@
1
- # --- Imports ---
 
2
  import streamlit as st
3
- import google.generativeai as genai
4
- import chromadb
5
- from chromadb.utils import embedding_functions
6
  from PIL import Image
7
- import io
8
- import time
9
- import logging
10
- from typing import Optional, Dict, List, Any, Tuple
11
-
12
- # --- Set Page Config FIRST ---
13
- # This MUST be the first Streamlit command executed in the script.
14
- st.set_page_config(layout="wide", page_title="Medical Image Analysis & RAG (HF/BioBERT)")
15
-
16
- # --- Basic Logging Setup ---
17
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
- logger = logging.getLogger(__name__)
19
-
20
- # --- Application Configuration ---
21
- # Secrets Management (Prioritize Hugging Face Secrets)
22
- try:
23
- GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
24
- HF_TOKEN = st.secrets.get("HF_TOKEN") # Use .get() for optional token
25
- except KeyError as e:
26
- err_msg = f"❌ Missing Secret: {e}. Please add it to your Hugging Face Space secrets."
27
- # Now it's safe to call st.error after set_page_config
28
- st.error(err_msg)
29
- logger.error(err_msg)
30
- st.stop()
31
- except Exception as e:
32
- err_msg = f"❌ Error loading secrets: {e}"
33
- st.error(err_msg)
34
- logger.error(err_msg)
35
- st.stop()
36
-
37
- # Gemini Configuration
38
- VISION_MODEL_NAME = "gemini-pro-vision"
39
- GENERATION_CONFIG = {
40
- "temperature": 0.2,
41
- "top_p": 0.95,
42
- "top_k": 40,
43
- "max_output_tokens": 1024,
44
- }
45
- SAFETY_SETTINGS = [
46
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
47
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
48
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
49
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
50
- ]
51
- GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
52
- Describe the key visual features relevant to a medical context.
53
- Identify potential:
54
- - Diseases or conditions indicated
55
- - Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
56
- - Visible cell types
57
- - Relevant biomarkers (if inferable from staining or morphology)
58
- - Anatomical context (if discernible)
59
-
60
- Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
61
- Structure the output clearly, perhaps using bullet points for findings.
62
- """
63
-
64
- # Chroma DB Configuration
65
- CHROMA_PATH = "chroma_data_biobert" # Changed path to reflect model change
66
- COLLECTION_NAME = "medical_docs_biobert" # Changed collection name
67
-
68
- # --- Embedding Model Selection ---
69
- # Using BioBERT v1.1 - Good domain knowledge, but potentially suboptimal for *semantic similarity search*.
70
- # Default pooling (likely CLS token) will be used by sentence-transformers.
71
- # Consider models fine-tuned for sentence similarity if retrieval quality is low:
72
- # e.g., 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'
73
- EMBEDDING_MODEL_NAME = "dmis-lab/biobert-v1.1"
74
- CHROMA_DISTANCE_METRIC = "cosine" # Cosine is generally good for sentence embeddings
75
-
76
- # --- Caching Resource Initialization ---
77
-
78
- @st.cache_resource
79
- def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
80
- """Initializes and returns the Gemini Generative Model."""
81
- try:
82
- genai.configure(api_key=GOOGLE_API_KEY)
83
- model = genai.GenerativeModel(
84
- model_name=VISION_MODEL_NAME,
85
- generation_config=GENERATION_CONFIG,
86
- safety_settings=SAFETY_SETTINGS
87
- )
88
- logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
89
- return model
90
- except Exception as e:
91
- err_msg = f"❌ Error initializing Gemini Model ({VISION_MODEL_NAME}): {e}"
92
- st.error(err_msg) # Safe to call st.error here now
93
- logger.error(err_msg, exc_info=True)
94
- return None
95
-
96
- @st.cache_resource
97
- def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
98
- """Initializes and returns the Hugging Face Embedding Function."""
99
- st.info(f"Initializing Embedding Model: {EMBEDDING_MODEL_NAME} (this may take a moment)...")
100
- try:
101
- # Pass HF_TOKEN if it exists (required for private/gated models)
102
- embed_func = embedding_functions.HuggingFaceEmbeddingFunction(
103
- api_key=HF_TOKEN, # Pass token here if needed by model
104
- model_name=EMBEDDING_MODEL_NAME
105
- )
106
- logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
107
- st.success(f"Embedding Model {EMBEDDING_MODEL_NAME} initialized.")
108
- return embed_func
109
- except Exception as e:
110
- err_msg = f"❌ Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
111
- st.error(err_msg) # Safe here
112
- logger.error(err_msg, exc_info=True)
113
- st.info("ℹ️ Make sure the embedding model name is correct and you have network access. "
114
- "If using a private model, ensure HF_TOKEN is set in secrets. Check Space logs for details.")
115
- return None
116
-
117
- @st.cache_resource
118
- def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
119
- """Initializes the Chroma DB client and returns the collection."""
120
- if not _embedding_func:
121
- st.error("❌ Cannot initialize Chroma DB without a valid embedding function.") # Safe here
122
- return None
123
- st.info(f"Initializing Chroma DB collection '{COLLECTION_NAME}'...")
124
- try:
125
- chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
126
- collection = chroma_client.get_or_create_collection(
127
- name=COLLECTION_NAME,
128
- embedding_function=_embedding_func, # Pass the initialized function
129
- metadata={"hnsw:space": CHROMA_DISTANCE_METRIC}
130
- )
131
- logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
132
- st.success(f"Chroma DB collection '{COLLECTION_NAME}' ready.")
133
- return collection
134
- except Exception as e:
135
- err_msg = f"❌ Error initializing Chroma DB at '{CHROMA_PATH}': {e}"
136
- st.error(err_msg) # Safe here
137
- logger.error(err_msg, exc_info=True)
138
- st.info(f"ℹ️ Ensure the path '{CHROMA_PATH}' is writable. Check Space logs.")
139
- return None
140
-
141
- # --- Core Logic Functions (with Caching for Data Operations) ---
142
-
143
- @st.cache_data(show_spinner=False) # Show spinner manually in UI
144
- def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
145
- """
146
- Analyzes image bytes with Gemini, returns (analysis_text, is_error).
147
- Uses Streamlit's caching based on image_bytes.
148
- """
149
- if not _gemini_model:
150
- return "Error: Gemini model not initialized.", True
151
-
152
- try:
153
- img = Image.open(io.BytesIO(image_bytes))
154
- response = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, img])
155
-
156
- if not response.parts:
157
- if response.prompt_feedback and response.prompt_feedback.block_reason:
158
- reason = response.prompt_feedback.block_reason
159
- msg = f"Analysis blocked by safety settings: {reason}"
160
- logger.warning(msg)
161
- return msg, True # Indicate block/error state
162
- else:
163
- msg = "Error: Gemini analysis returned no content (empty or invalid response)."
164
- logger.error(msg)
165
- return msg, True
166
- logger.info("Gemini analysis successful.")
167
- return response.text, False # Indicate success
168
-
169
- except genai.types.BlockedPromptException as e:
170
- msg = f"Analysis blocked (prompt issue): {e}"
171
- logger.warning(msg)
172
- return msg, True
173
- except Exception as e:
174
- msg = f"Error during Gemini analysis: {e}"
175
- logger.error(msg, exc_info=True)
176
- return msg, True
177
-
178
- @st.cache_data(show_spinner=False)
179
- def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
180
- """Queries Chroma DB, returns results dict or None on error."""
181
- if not _collection:
182
- logger.error("Query attempt failed: Chroma collection is not available.")
183
- return None
184
- if not query_text:
185
- logger.warning("Attempted to query Chroma with empty text.")
186
- return None
187
- try:
188
- refined_query = query_text # Using direct analysis text for now
189
-
190
- results = _collection.query(
191
- query_texts=[refined_query],
192
- n_results=n_results,
193
- include=['documents', 'metadatas', 'distances']
194
- )
195
- logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
196
- return results
197
- except Exception as e:
198
- # Show error in UI as well
199
- st.error(f"❌ Error querying Chroma DB: {e}", icon="🚨")
200
- logger.error(f"Error querying Chroma DB: {e}", exc_info=True)
201
- return None
202
-
203
- def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
204
- """Adds example medical text snippets to Chroma using the provided embedding function."""
205
- if not collection or not embedding_func:
206
- st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function not available.")
207
- return
208
-
209
- # Check if dummy data needs adding first to avoid unnecessary processing
210
- docs_to_check = [
211
- "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive."
212
- ] # Only check one doc for speed
213
- try:
214
- existing_check = collection.get(where={"document": docs_to_check[0]}, limit=1, include=[])
215
- if existing_check and existing_check.get('ids'):
216
- st.info("Dummy data seems to already exist. Skipping add.")
217
- logger.info("Skipping dummy data addition as it likely exists.")
218
- return
219
- except Exception as e:
220
- logger.warning(f"Could not efficiently check for existing dummy data: {e}. Proceeding with add attempt.")
221
-
222
-
223
- status = st.status(f"Adding dummy data (using {EMBEDDING_MODEL_NAME})...", expanded=True)
224
- try:
225
- # --- Dummy Data Definition ---
226
- docs = [
227
- "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
228
- "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
229
- "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
230
- "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
231
- "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted."
232
- ]
233
- metadatas = [
234
- {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
235
- {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
236
- {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer", "IMAGE_ID": "diagram_egfr_pathway.svg"},
237
- {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
238
- {"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"}
239
- ]
240
- # Ensure IDs are unique even if run close together
241
- base_id = f"doc_biobert_{int(time.time() * 1000)}"
242
- ids = [f"{base_id}_{i}" for i in range(len(docs))]
243
-
244
- status.update(label=f"Generating embeddings & adding {len(docs)} documents (this uses BioBERT and may take time)...")
245
-
246
- # Embeddings are generated implicitly by ChromaDB during .add()
247
- collection.add(
248
- documents=docs,
249
- metadatas=metadatas,
250
- ids=ids
251
- )
252
- status.update(label=f"βœ… Added {len(docs)} dummy documents.", state="complete", expanded=False)
253
- logger.info(f"Added {len(docs)} dummy documents to collection '{COLLECTION_NAME}'.")
254
-
255
- except Exception as e:
256
- err_msg = f"Error adding dummy data to Chroma: {e}"
257
- status.update(label=f"❌ Error: {err_msg}", state="error", expanded=True)
258
- logger.error(err_msg, exc_info=True)
259
-
260
- # --- Initialize Resources ---
261
- # These calls use @st.cache_resource, run only once unless cleared/changed.
262
- # Order matters if one depends on another (embedding func needed for chroma).
263
- gemini_model = initialize_gemini_model()
264
- embedding_func = initialize_embedding_function()
265
- collection = initialize_chroma_collection(embedding_func) # Pass embedding func
266
-
267
- # --- Streamlit UI ---
268
- # set_page_config() is already called at the top
269
-
270
- st.title("βš•οΈ Medical Image Analysis & RAG (BioBERT Embeddings)")
271
-
272
- # --- DISCLAIMER ---
273
- st.warning("""
274
- **⚠️ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
275
- It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis, treatment, or decision-making.
276
- AI analysis can be imperfect. Always consult with qualified healthcare professionals for any medical concerns.
277
- Do **NOT** upload identifiable patient data (PHI). Analysis quality depends heavily on the chosen embedding model.
278
- """, icon="☣️")
279
-
280
- st.markdown(f"""
281
- Upload a medical image. Gemini Vision will analyze it. Related information
282
- will be retrieved from a Chroma DB knowledge base using **{EMBEDDING_MODEL_NAME}** embeddings.
283
- """)
284
-
285
- # Sidebar
286
- with st.sidebar:
287
- st.header("βš™οΈ Controls")
288
- uploaded_file = st.file_uploader(
289
- "Choose an image...",
290
- type=["jpg", "jpeg", "png", "tiff", "webp"],
291
- help="Upload a medical image file (e.g., pathology, diagram)."
292
  )
 
293
 
294
- st.divider()
295
-
296
- if st.button("βž• Add/Verify Dummy KB Data", help=f"Adds example text data to Chroma DB ({COLLECTION_NAME}) if it doesn't exist."):
297
- if collection and embedding_func:
298
- add_dummy_data_to_chroma(collection, embedding_func)
299
- else:
300
- st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
301
-
302
- st.divider()
303
-
304
- st.header("ℹ️ System Info")
305
- st.caption(f"**Gemini Model:** `{VISION_MODEL_NAME}`")
306
- st.caption(f"**Embedding Model:** `{EMBEDDING_MODEL_NAME}`")
307
- st.caption(f"**Chroma Collection:** `{COLLECTION_NAME}`")
308
- st.caption(f"**Chroma Path:** `{CHROMA_PATH}`")
309
- st.caption(f"**Distance Metric:** `{CHROMA_DISTANCE_METRIC}`")
310
- st.caption(f"**Google API Key:** {'Set' if GOOGLE_API_KEY else 'Not Set'}")
311
- st.caption(f"**HF Token:** {'Provided' if HF_TOKEN else 'Not Provided'}")
312
-
313
- # Main Display Area
314
- col1, col2 = st.columns(2)
315
-
316
- with col1:
317
- st.subheader("πŸ–ΌοΈ Uploaded Image")
318
- if uploaded_file is not None:
319
- image_bytes = uploaded_file.getvalue()
320
- st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
321
  else:
322
- st.info("Upload an image using the sidebar to begin.")
323
-
324
- with col2:
325
- st.subheader("πŸ”¬ Analysis & Retrieval")
326
- if uploaded_file is not None and gemini_model and collection:
327
- # 1. Analyze Image
328
- analysis_text = ""
329
- analysis_error = False
330
- with st.status("🧠 Analyzing image with Gemini Vision...", expanded=True) as status_gemini:
331
- analysis_text, analysis_error = analyze_image_with_gemini(gemini_model, image_bytes)
332
- if analysis_error:
333
- # Shorten the message for status if needed
334
- status_label = f"⚠️ Analysis Failed/Blocked: {analysis_text.split(':')[0]}"
335
- status_gemini.update(label=status_label , state="error")
336
- st.error(f"**Analysis Output:** {analysis_text}", icon="🚨")
337
- else:
338
- status_gemini.update(label="βœ… Analysis Complete", state="complete", expanded=False)
339
- st.markdown("**Gemini Vision Analysis:**")
340
- st.markdown(analysis_text) # Display the successful analysis
341
-
342
- # 2. Query Chroma if Analysis Succeeded
343
- if not analysis_error and analysis_text:
344
- st.markdown("---") # Separator
345
- st.subheader("πŸ“š Related Information (RAG)")
346
- with st.status(f"πŸ” Searching knowledge base (Chroma DB w/ BioBERT)...", expanded=True) as status_chroma:
347
- chroma_results = query_chroma(collection, analysis_text, n_results=3) # Fetch top 3
348
-
349
- if chroma_results and chroma_results.get('documents') and chroma_results['documents'][0]:
350
- num_results = len(chroma_results['documents'][0])
351
- status_chroma.update(label=f"βœ… Found {num_results} related entries.", state="complete", expanded=False)
352
-
353
- for i in range(num_results):
354
- doc = chroma_results['documents'][0][i]
355
- meta = chroma_results['metadatas'][0][i]
356
- dist = chroma_results['distances'][0][i]
357
- # Ensure distance is float before calculation
358
- similarity = 1.0 - float(dist) if dist is not None else 0.0
359
-
360
- expander_title = f"Result {i+1} (Similarity: {similarity:.4f}) | Source: {meta.get('source', 'N/A')}"
361
- with st.expander(expander_title):
362
- st.markdown("**Retrieved Text:**")
363
- st.markdown(f"> {doc}") # Use blockquote
364
- st.markdown("**Metadata:**")
365
- for key, value in meta.items():
366
- st.markdown(f"- **{key.replace('_', ' ').title()}:** `{value}`")
367
- if meta.get("IMAGE_ID"):
368
- st.info(f"ℹ️ Associated visual asset ID: `{meta['IMAGE_ID']}`")
369
-
370
- elif chroma_results is not None: # Query ran, no results
371
- status_chroma.update(label="⚠️ No relevant information found.", state="warning", expanded=False)
372
- st.warning("No relevant documents found in the knowledge base for this analysis.", icon="⚠️")
373
- # Error case is handled by st.error within query_chroma itself
374
- elif chroma_results is None:
375
- status_chroma.update(label="❌ Failed to retrieve results.", state="error", expanded=True)
376
-
377
-
378
- elif not uploaded_file:
379
- st.info("Analysis results will appear here once an image is uploaded.")
380
- else:
381
- # Initialization error occurred earlier, resources might be None
382
- st.error("❌ Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar info & Space logs).")
383
-
384
-
385
- st.markdown("---")
386
- st.markdown("<div style='text-align: center; font-size: small;'>Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit</div>", unsafe_allow_html=True)
387
-
388
-
 
1
+ # app.py
2
+ import os
3
  import streamlit as st
4
+ from streamlit_drawable_canvas import st_canvas
 
 
5
  from PIL import Image
6
+ import openai
7
+ from io import BytesIO
8
+ import json
9
+
10
+ # ─── 1. Configuration & Secrets ─────────────────────────────────────────────
11
+ openai.api_key = st.secrets["OPENAI_API_KEY"] # or os.getenv("OPENAI_API_KEY")
12
+ st.set_page_config(
13
+ page_title="MedSketchβ€―AI",
14
+ layout="wide",
15
+ initial_sidebar_state="expanded",
16
+ )
17
+
18
+ # ─── 2. Sidebar: Settings & Metadata ────────────────────────────────────────
19
+ st.sidebar.header("βš™οΈ Settings")
20
+ model_choice = st.sidebar.selectbox(
21
+ "Model",
22
+ ["GPT-4o (API)", "Stable Diffusion LoRA"],
23
+ index=0
24
+ )
25
+ style_preset = st.sidebar.radio(
26
+ "Preset Style",
27
+ ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
28
+ )
29
+ strength = st.sidebar.slider("Stylization Strength", 0.1, 1.0, 0.7)
30
+
31
+ st.sidebar.markdown("---")
32
+ st.sidebar.header("πŸ“‹ Metadata")
33
+ patient_id = st.sidebar.text_input("Patient / Case ID")
34
+ roi = st.sidebar.text_input("Region of Interest")
35
+ umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")
36
+
37
+ # ─── 3. Main: Prompt Input & Batch Generation ───────────────────────────────
38
+ st.title("πŸ–ΌοΈ MedSketchβ€―AI – Advanced Clinical Diagram Generator")
39
+
40
+ with st.expander("πŸ“ Enter Prompts (one per line for batch)"):
41
+ raw = st.text_area(
42
+ "Describe what you need:",
43
+ placeholder=(
44
+ "e.g. β€œGenerate a labeled cross‑section of the human heart with chamber names, valves, and flow arrows…”\n"
45
+ "e.g. β€œProduce a stylized H&E stain of liver tissue highlighting portal triads…”"
46
+ ),
47
+ height=120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  )
49
+ prompts = [p.strip() for p in raw.splitlines() if p.strip()]
50
 
51
+ if st.button("πŸš€ Generate"):
52
+ if not prompts:
53
+ st.error("Please enter at least one prompt.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  else:
55
+ cols = st.columns(min(3, len(prompts)))
56
+ for i, prompt in enumerate(prompts):
57
+ with st.spinner(f"Rendering image {i+1}/{len(prompts)}…"):
58
+ if model_choice == "GPT-4o (API)":
59
+ resp = openai.Image.create(
60
+ model="gpt-4o",
61
+ prompt=f"[{style_preset} | strength={strength}] {prompt}",
62
+ size="1024x1024"
63
+ )
64
+ img_data = requests.get(resp["data"][0]["url"]).content
65
+ else:
66
+ # stub for Stable Diffusion LoRA
67
+ img_data = generate_sd_image(prompt, style=style_preset, strength=strength)
68
+ img = Image.open(BytesIO(img_data))
69
+
70
+ # Display + Download
71
+ with cols[i]:
72
+ st.image(img, use_column_width=True, caption=prompt)
73
+ buf = BytesIO()
74
+ img.save(buf, format="PNG")
75
+ st.download_button(
76
+ label="⬇️ Download PNG",
77
+ data=buf.getvalue(),
78
+ file_name=f"medsketch_{i+1}.png",
79
+ mime="image/png"
80
+ )
81
+
82
+ # ─── Annotation Canvas ───────────────────────────
83
+ st.markdown("**✏️ Annotate:**")
84
+ canvas_res = st_canvas(
85
+ fill_color="rgba(255, 0, 0, 0.3)", # annotation color
86
+ stroke_width=2,
87
+ background_image=img,
88
+ update_streamlit=True,
89
+ height=512,
90
+ width=512,
91
+ drawing_mode="freedraw",
92
+ key=f"canvas_{i}"
93
+ )
94
+ # Save annotations
95
+ if canvas_res.json_data:
96
+ ann = canvas_res.json_data["objects"]
97
+ st.session_state.setdefault("annotations", {})[prompt] = ann
98
+
99
+ # ─── 4. History & Exports ───────────────────────────────────────────────────
100
+ if "annotations" in st.session_state:
101
+ st.markdown("---")
102
+ st.subheader("πŸ“š Session History & Annotations")
103
+ for prm, objs in st.session_state["annotations"].items():
104
+ st.markdown(f"**Prompt:** {prm}")
105
+ st.json(objs)
106
+ st.download_button(
107
+ "⬇️ Export All Annotations (JSON)",
108
+ data=json.dumps(st.session_state["annotations"], indent=2),
109
+ file_name="medsketch_annotations.json",
110
+ mime="application/json"
111
+ )