mgbam commited on
Commit
943c488
Β·
verified Β·
1 Parent(s): e286f72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +416 -85
app.py CHANGED
@@ -1,111 +1,442 @@
1
  # app.py
 
 
 
 
 
 
 
 
2
  import os
 
 
 
 
 
3
  import streamlit as st
4
  from streamlit_drawable_canvas import st_canvas
5
  from PIL import Image
6
  import openai
7
- from io import BytesIO
8
- import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # ─── 1. Configuration & Secrets ─────────────────────────────────────────────
11
- openai.api_key = st.secrets["OPENAI_API_KEY"] # or os.getenv("OPENAI_API_KEY")
12
  st.set_page_config(
13
- page_title="MedSketchβ€―AI",
14
  layout="wide",
15
  initial_sidebar_state="expanded",
 
 
 
 
 
16
  )
17
 
18
- # ─── 2. Sidebar: Settings & Metadata ────────────────────────────────────────
19
- st.sidebar.header("βš™οΈ Settings")
20
- model_choice = st.sidebar.selectbox(
21
- "Model",
22
- ["GPT-4o (API)", "Stable Diffusion LoRA"],
23
- index=0
24
- )
25
- style_preset = st.sidebar.radio(
26
- "Preset Style",
27
- ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
28
- )
29
- strength = st.sidebar.slider("Stylization Strength", 0.1, 1.0, 0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- st.sidebar.markdown("---")
32
- st.sidebar.header("πŸ“‹ Metadata")
33
- patient_id = st.sidebar.text_input("Patient / Case ID")
34
- roi = st.sidebar.text_input("Region of Interest")
35
- umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")
36
 
37
- # ─── 3. Main: Prompt Input & Batch Generation ───────────────────────────────
38
- st.title("πŸ–ΌοΈ MedSketchβ€―AI – Advanced Clinical Diagram Generator")
39
 
40
- with st.expander("πŸ“ Enter Prompts (one per line for batch)"):
41
- raw = st.text_area(
42
- "Describe what you need:",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  placeholder=(
44
- "e.g. β€œGenerate a labeled cross‑section of the human heart with chamber names, valves, and flow arrows…”\n"
45
- "e.g. β€œProduce a stylized H&E stain of liver tissue highlighting portal triads…”"
 
46
  ),
47
- height=120
 
48
  )
49
- prompts = [p.strip() for p in raw.splitlines() if p.strip()]
50
 
51
- if st.button("πŸš€ Generate"):
 
 
 
 
 
 
 
 
 
 
52
  if not prompts:
53
- st.error("Please enter at least one prompt.")
54
  else:
55
- cols = st.columns(min(3, len(prompts)))
 
 
 
 
 
 
 
56
  for i, prompt in enumerate(prompts):
57
- with st.spinner(f"Rendering image {i+1}/{len(prompts)}…"):
58
- if model_choice == "GPT-4o (API)":
59
- resp = openai.Image.create(
60
- model="gpt-4o",
61
- prompt=f"[{style_preset} | strength={strength}] {prompt}",
62
- size="1024x1024"
63
- )
64
- img_data = requests.get(resp["data"][0]["url"]).content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  else:
66
- # stub for Stable Diffusion LoRA
67
- img_data = generate_sd_image(prompt, style=style_preset, strength=strength)
68
- img = Image.open(BytesIO(img_data))
69
-
70
- # Display + Download
71
- with cols[i]:
72
- st.image(img, use_column_width=True, caption=prompt)
73
- buf = BytesIO()
74
- img.save(buf, format="PNG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  st.download_button(
76
- label="⬇️ Download PNG",
77
- data=buf.getvalue(),
78
- file_name=f"medsketch_{i+1}.png",
79
- mime="image/png"
80
- )
81
-
82
- # ─── Annotation Canvas ───────────────────────────
83
- st.markdown("**✏️ Annotate:**")
84
- canvas_res = st_canvas(
85
- fill_color="rgba(255, 0, 0, 0.3)", # annotation color
86
- stroke_width=2,
87
- background_image=img,
88
- update_streamlit=True,
89
- height=512,
90
- width=512,
91
- drawing_mode="freedraw",
92
- key=f"canvas_{i}"
93
  )
94
- # Save annotations
95
- if canvas_res.json_data:
96
- ann = canvas_res.json_data["objects"]
97
- st.session_state.setdefault("annotations", {})[prompt] = ann
98
 
99
- # ─── 4. History & Exports ───────────────────────────────────────────────────
100
- if "annotations" in st.session_state:
101
- st.markdown("---")
102
- st.subheader("πŸ“š Session History & Annotations")
103
- for prm, objs in st.session_state["annotations"].items():
104
- st.markdown(f"**Prompt:** {prm}")
105
- st.json(objs)
106
- st.download_button(
107
- "⬇️ Export All Annotations (JSON)",
108
- data=json.dumps(st.session_state["annotations"], indent=2),
109
- file_name="medsketch_annotations.json",
110
- mime="application/json"
111
- )
 
1
  # app.py
2
+ """
3
+ MedSketch AI: Advanced Clinical Diagram Generator
4
+
5
+ A Streamlit application leveraging AI models (GPT-4o, potentially Stable Diffusion)
6
+ to generate medical diagrams based on user prompts, with options for styling,
7
+ metadata association, and annotations.
8
+ """
9
+
10
  import os
11
+ import json
12
+ import logging
13
+ from io import BytesIO
14
+ from typing import List, Dict, Any, Optional, Tuple
15
+
16
  import streamlit as st
17
  from streamlit_drawable_canvas import st_canvas
18
  from PIL import Image
19
  import openai
20
+ from openai import OpenAI, OpenAIError # Use modern OpenAI client and error types
21
+
22
+ # ─── Constants ───────────────────────────────────────────────────────────────
23
+
24
+ APP_TITLE = "MedSketch AI – Advanced Clinical Diagram Generator"
25
+ DEFAULT_MODEL = "GPT-4o (Vision)" # Updated model name
26
+ STABLE_DIFFUSION_MODEL = "Stable Diffusion LoRA" # Placeholder name
27
+ MODEL_OPTIONS = [DEFAULT_MODEL, STABLE_DIFFUSION_MODEL]
28
+ STYLE_PRESETS = ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
29
+ DEFAULT_STYLE = "Anatomical Diagram"
30
+ DEFAULT_STRENGTH = 0.7
31
+ IMAGE_SIZE = "1024x1024"
32
+ CANVAS_SIZE = 512
33
+ ANNOTATION_COLOR = "rgba(255, 0, 0, 0.3)" # Red with transparency
34
+ ANNOTATION_STROKE_WIDTH = 2
35
+ SESSION_STATE_ANNOTATIONS = "medsketch_annotations"
36
+ SESSION_STATE_HISTORY = "medsketch_history" # Store generated images too
37
+
38
+ # ─── Setup & Configuration ────────────────────────────────────────────────────
39
+
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
41
+ logger = logging.getLogger(__name__)
42
 
 
 
43
  st.set_page_config(
44
+ page_title=APP_TITLE,
45
  layout="wide",
46
  initial_sidebar_state="expanded",
47
+ menu_items={
48
+ 'About': f"{APP_TITLE} - Generates medical diagrams using AI.",
49
+ 'Get Help': None, # Add a link if you have one
50
+ 'Report a bug': None # Add a link if you have one
51
+ }
52
  )
53
 
54
+ # Initialize OpenAI Client (Best Practice)
55
+ # Use st.secrets for deployment, fallback to env var for local dev
56
+ api_key = st.secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY"))
57
+ if not api_key:
58
+ st.error("🚨 OpenAI API Key not found! Please set it in Streamlit secrets or environment variables.", icon="🚨")
59
+ st.stop() # Halt execution if no key
60
+
61
+ try:
62
+ client = OpenAI(api_key=api_key)
63
+ logger.info("OpenAI client initialized successfully.")
64
+ except Exception as e:
65
+ st.error(f"🚨 Failed to initialize OpenAI client: {e}", icon="🚨")
66
+ logger.exception("OpenAI client initialization failed.")
67
+ st.stop()
68
+
69
+ # ─── Helper Functions ─────────────────────────────────────────────────────────
70
+
71
+ def generate_openai_image(prompt: str, style: str, strength: float) -> Image.Image:
72
+ """
73
+ Generates an image using the OpenAI API (GPT-4o).
74
+
75
+ Args:
76
+ prompt: The user's text prompt.
77
+ style: The selected style preset.
78
+ strength: The stylization strength (conceptually used in prompt).
79
+
80
+ Returns:
81
+ A PIL Image object.
82
+
83
+ Raises:
84
+ OpenAIError: If the API call fails.
85
+ IOError: If the image data cannot be processed.
86
+ """
87
+ logger.info(f"Requesting OpenAI image generation for prompt: '{prompt}' with style '{style}'")
88
+ full_prompt = f"Style: [{style}], Strength: [{strength:.2f}] - Generate the following medical illustration: {prompt}"
89
+ try:
90
+ response = client.images.generate(
91
+ model="dall-e-3", # Or "gpt-4o" if/when available via this endpoint. DALL-E 3 is current standard.
92
+ prompt=full_prompt,
93
+ size=IMAGE_SIZE,
94
+ quality="standard", # or "hd"
95
+ n=1,
96
+ response_format="url" # Or "b64_json" to avoid a second request
97
+ )
98
+ image_url = response.data[0].url
99
+ logger.info(f"Image generated successfully, URL: {image_url}")
100
+
101
+ # Fetch the image data from the URL
102
+ # Note: Using response_format="b64_json" would avoid this extra step
103
+ import requests # Need to import requests library
104
+ image_response = requests.get(image_url, timeout=30) # Add timeout
105
+ image_response.raise_for_status() # Check for HTTP errors
106
+
107
+ img_data = BytesIO(image_response.content)
108
+ img = Image.open(img_data)
109
+ return img
110
+
111
+ except OpenAIError as e:
112
+ logger.error(f"OpenAI API error: {e}")
113
+ st.error(f"❌ OpenAI API Error: {e}", icon="❌")
114
+ raise
115
+ except requests.exceptions.RequestException as e:
116
+ logger.error(f"Failed to download image from URL {image_url}: {e}")
117
+ st.error(f"❌ Network Error: Failed to download image. {e}", icon="❌")
118
+ raise IOError(f"Failed to download image: {e}") from e
119
+ except Exception as e:
120
+ logger.exception(f"An unexpected error occurred during OpenAI image generation: {e}")
121
+ st.error(f"❌ An unexpected error occurred: {e}", icon="❌")
122
+ raise
123
+
124
+
125
+ def generate_sd_image(prompt: str, style: str, strength: float) -> Image.Image:
126
+ """
127
+ Placeholder for generating an image using a Stable Diffusion LoRA model.
128
+ Replace this with your actual implementation.
129
+
130
+ Args:
131
+ prompt: The user's text prompt.
132
+ style: The selected style preset.
133
+ strength: The stylization strength.
134
+
135
+ Returns:
136
+ A PIL Image object (dummy implementation).
137
+
138
+ Raises:
139
+ NotImplementedError: As this is a placeholder.
140
+ """
141
+ logger.warning("Stable Diffusion LoRA model generation is not implemented. Returning placeholder.")
142
+ st.warning("🚧 Stable Diffusion LoRA generation is not yet implemented. Using placeholder.", icon="🚧")
143
+
144
+ # --- Placeholder Implementation ---
145
+ # Replace this with actual SD model call
146
+ # For now, create a simple dummy image with text
147
+ img = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), color = (210, 210, 210))
148
+ from PIL import ImageDraw
149
+ d = ImageDraw.Draw(img)
150
+ d.text((10,10), f"Stable Diffusion Placeholder\nStyle: {style}\nPrompt: {prompt[:50]}...", fill=(0,0,0))
151
+ # --- End Placeholder ---
152
+
153
+ # Simulate some processing time
154
+ import time
155
+ time.sleep(1)
156
+ return img
157
+ # raise NotImplementedError("Stable Diffusion LoRA generation is not yet available.")
158
+
159
+
160
+ def display_result(image: Image.Image, prompt: str, index: int, total: int) -> Optional[List[Dict[str, Any]]]:
161
+ """
162
+ Displays a generated image, download button, and annotation canvas.
163
+
164
+ Args:
165
+ image: The PIL Image to display.
166
+ prompt: The prompt used to generate the image.
167
+ index: The index of the current image in a batch.
168
+ total: The total number of images in the batch.
169
+
170
+ Returns:
171
+ Annotation data (list of dicts) if annotations were made, otherwise None.
172
+ """
173
+ st.image(image, caption=f"Result {index + 1}/{total}: {prompt}", use_column_width='always')
174
+
175
+ # Prepare image for download
176
+ buf = BytesIO()
177
+ image.save(buf, format="PNG")
178
+ buf.seek(0)
179
+
180
+ st.download_button(
181
+ label="⬇️ Download PNG",
182
+ data=buf,
183
+ file_name=f"medsketch_{index+1}_{prompt[:20].replace(' ', '_')}.png",
184
+ mime="image/png",
185
+ key=f"download_{index}"
186
+ )
187
+
188
+ # Annotation Canvas
189
+ st.markdown("**✏️ Annotate:**")
190
+ # Resize image for canvas if needed, maintaining aspect ratio (optional)
191
+ # For simplicity, we assume the canvas size matches desired annotation size
192
+ canvas_image = image.copy()
193
+ canvas_image.thumbnail((CANVAS_SIZE, CANVAS_SIZE))
194
+
195
+ canvas_result = st_canvas(
196
+ fill_color=ANNOTATION_COLOR,
197
+ stroke_width=ANNOTATION_STROKE_WIDTH,
198
+ background_image=canvas_image,
199
+ update_streamlit=True, # Update in real-time
200
+ height=canvas_image.height,
201
+ width=canvas_image.width,
202
+ drawing_mode="freedraw", # Or choose other modes like "line", "rect", etc.
203
+ key=f"canvas_{index}"
204
+ )
205
 
206
+ if canvas_result.json_data and canvas_result.json_data.get("objects"):
207
+ return canvas_result.json_data["objects"]
208
+ return None
 
 
209
 
210
+ # ─── Initialize Session State ───────────────────────────────────────────────
 
211
 
212
+ if SESSION_STATE_ANNOTATIONS not in st.session_state:
213
+ st.session_state[SESSION_STATE_ANNOTATIONS] = {} # Dict[prompt, List[annotation_objects]]
214
+ if SESSION_STATE_HISTORY not in st.session_state:
215
+ st.session_state[SESSION_STATE_HISTORY] = [] # List[Dict[str, Any]] storing generation results
216
+
217
+ # ─── Sidebar: Settings & Metadata ───────────────────────────────────────────
218
+
219
+ with st.sidebar:
220
+ st.header("βš™οΈ Generation Settings")
221
+ model_choice = st.selectbox(
222
+ "Select Model",
223
+ options=MODEL_OPTIONS,
224
+ index=MODEL_OPTIONS.index(DEFAULT_MODEL),
225
+ help="Choose the AI model for image generation."
226
+ )
227
+
228
+ style_preset = st.radio(
229
+ "Select Preset Style",
230
+ options=STYLE_PRESETS,
231
+ index=STYLE_PRESETS.index(DEFAULT_STYLE),
232
+ horizontal=True, # More compact layout
233
+ help="Apply a predefined visual style to the generation."
234
+ )
235
+ # Allow custom style input only if "Custom" is selected
236
+ custom_style_input = ""
237
+ if style_preset == "Custom":
238
+ custom_style_input = st.text_input("Enter Custom Style Description:", key="custom_style")
239
+ final_style = custom_style_input if style_preset == "Custom" else style_preset
240
+
241
+
242
+ strength = st.slider(
243
+ "Stylization Strength",
244
+ min_value=0.1,
245
+ max_value=1.0,
246
+ value=DEFAULT_STRENGTH,
247
+ step=0.05,
248
+ help="Controls how strongly the chosen style influences the result (conceptual)."
249
+ )
250
+
251
+ st.markdown("---")
252
+ st.header("πŸ“‹ Optional Metadata")
253
+ patient_id = st.text_input("Patient / Case ID", key="patient_id", help="Associate with a specific patient or case.")
254
+ roi = st.text_input("Region of Interest (ROI)", key="roi", help="Specify the anatomical region shown.")
255
+ umls_code = st.text_input("UMLS / SNOMED CT Code", key="umls_code", help="Link to relevant medical ontology codes.")
256
+
257
+ # Add a clear history button
258
+ st.markdown("---")
259
+ if st.button("⚠️ Clear History & Annotations", help="Removes all generated images and annotations from this session."):
260
+ st.session_state[SESSION_STATE_ANNOTATIONS] = {}
261
+ st.session_state[SESSION_STATE_HISTORY] = []
262
+ st.rerun() # Refresh the page to reflect cleared state
263
+
264
+ # ─── Main Application Area ───────────────────────────────────────────────────
265
+
266
+ st.title(APP_TITLE)
267
+ st.markdown("Generate medical illustrations from text descriptions using AI. Annotate and export your results.")
268
+
269
+ # --- Prompt Input Area ---
270
+ prompt_input_area = st.container()
271
+ with prompt_input_area:
272
+ st.subheader("πŸ“ Enter Prompt(s)")
273
+ st.caption("Enter one prompt per line to generate multiple images in a batch.")
274
+ raw_prompts = st.text_area(
275
+ "Describe the medical diagram(s) you need:",
276
  placeholder=(
277
+ "Example 1: A sagittal view of the human knee joint, labeling the ACL, PCL, meniscus, femur, and tibia.\n"
278
+ "Example 2: High-power field H&E stain of lung adenocarcinoma showing glandular formation.\n"
279
+ "Example 3: Immunohistochemistry (IHC) stain for PD-L1 in tonsil tissue, showing positive staining on immune cells."
280
  ),
281
+ height=150, # Slightly larger height
282
+ label_visibility="collapsed"
283
  )
284
+ prompts: List[str] = [p.strip() for p in raw_prompts.splitlines() if p.strip()]
285
 
286
+ # --- Generation Trigger ---
287
+ generate_button = st.button(
288
+ f"πŸš€ Generate Diagram{'s' if len(prompts) > 1 else ''}",
289
+ type="primary",
290
+ disabled=not prompts, # Disable if no prompts
291
+ use_container_width=True
292
+ )
293
+
294
+ # --- Generation and Display Area ---
295
+ results_area = st.container()
296
+ if generate_button:
297
  if not prompts:
298
+ st.warning("⚠️ Please enter at least one prompt description.", icon="⚠️")
299
  else:
300
+ logger.info(f"Starting generation for {len(prompts)} prompts using model '{model_choice}'.")
301
+ num_prompts = len(prompts)
302
+ max_cols = 3 # Adjust number of columns based on screen width or preference
303
+ cols = st.columns(min(max_cols, num_prompts))
304
+
305
+ # Use a progress bar for batch generation
306
+ progress_bar = st.progress(0, text=f"Initializing generation...")
307
+
308
  for i, prompt in enumerate(prompts):
309
+ col_index = i % max_cols
310
+ with cols[col_index]:
311
+ st.markdown(f"--- \n**Processing: {i+1}/{num_prompts}**")
312
+ spinner_msg = f"Generating image {i+1}/{num_prompts} for prompt: \"{prompt[:50]}...\""
313
+ with st.spinner(spinner_msg):
314
+ try:
315
+ # Select generation function based on model choice
316
+ if model_choice == DEFAULT_MODEL:
317
+ generated_image = generate_openai_image(prompt, final_style, strength)
318
+ elif model_choice == STABLE_DIFFUSION_MODEL:
319
+ generated_image = generate_sd_image(prompt, final_style, strength)
320
+ else:
321
+ st.error(f"Unknown model selected: {model_choice}", icon="❌")
322
+ continue # Skip to next prompt
323
+
324
+ # Display result and get annotations
325
+ annotations = display_result(generated_image, prompt, i, num_prompts)
326
+
327
+ # Store results and annotations in session state
328
+ result_data = {
329
+ "prompt": prompt,
330
+ "model": model_choice,
331
+ "style": final_style,
332
+ "strength": strength,
333
+ "metadata": {
334
+ "patient_id": patient_id,
335
+ "roi": roi,
336
+ "umls_code": umls_code,
337
+ },
338
+ # Store image data efficiently (e.g., as base64 or keep PIL object if memory allows)
339
+ # For simplicity here, we might just store prompt and annotations.
340
+ # Storing images in session state can consume a lot of memory.
341
+ # Let's store the prompt reference and annotations.
342
+ "image_ref_index": i # Reference to this generation instance
343
+ }
344
+ st.session_state[SESSION_STATE_HISTORY].append(result_data)
345
+
346
+ if annotations:
347
+ st.session_state[SESSION_STATE_ANNOTATIONS][prompt] = annotations
348
+ st.success(f"Annotations saved for prompt {i+1}.", icon="βœ…")
349
+
350
+ except (OpenAIError, IOError, NotImplementedError, Exception) as e:
351
+ # Errors are logged and displayed by the generation functions
352
+ st.error(f"Failed to generate image for prompt: '{prompt}'. Error: {e}", icon="πŸ”₯")
353
+ # Optionally add failed attempts to history?
354
+ st.session_state[SESSION_STATE_HISTORY].append({
355
+ "prompt": prompt, "status": "failed", "error": str(e)
356
+ })
357
+
358
+ # Update progress bar
359
+ progress_val = (i + 1) / num_prompts
360
+ progress_bar.progress(progress_val, text=f"Generated {i+1}/{num_prompts} images...")
361
+
362
+ progress_bar.progress(1.0, text="Batch generation complete!")
363
+ st.toast(f"Finished generating {num_prompts} image(s)!", icon="πŸŽ‰")
364
+ # Explicitly clear the progress bar after completion
365
+ # (Streamlit often handles this, but explicit removal can be cleaner)
366
+ # Consider removing or hiding the progress bar element if needed after completion.
367
+
368
+
369
+ # ─── History & Exports Section ───────────────────────────────────────────────
370
+
371
+ history_area = st.container()
372
+ with history_area:
373
+ # Use session state history which is more robust
374
+ if st.session_state[SESSION_STATE_HISTORY]:
375
+ st.markdown("---")
376
+ st.subheader("πŸ“š Session History & Annotations")
377
+ st.caption("Review generated images (if stored) and their annotations from this session.")
378
+
379
+ # Display stored history (simplified view focusing on annotations)
380
+ for idx, item in enumerate(st.session_state[SESSION_STATE_HISTORY]):
381
+ if item.get("status") == "failed":
382
+ st.warning(f"**Prompt {idx+1} (Failed):** {item['prompt']} \n *Error: {item['error']}*", icon="⚠️")
383
+ else:
384
+ prompt_key = item["prompt"]
385
+ st.markdown(f"**Prompt {idx+1}:** `{prompt_key}`")
386
+ st.markdown(f"*Model: {item['model']}, Style: {item['style']}*")
387
+ # Display metadata if present
388
+ meta = item.get('metadata', {})
389
+ if any(meta.values()):
390
+ meta_str = ", ".join([f"{k}: {v}" for k, v in meta.items() if v])
391
+ st.markdown(f"*Metadata: {meta_str}*")
392
+
393
+ # Check for annotations for this prompt
394
+ annotations = st.session_state[SESSION_STATE_ANNOTATIONS].get(prompt_key)
395
+ if annotations:
396
+ with st.expander(f"View {len(annotations)} Annotation(s)"):
397
+ st.json(annotations)
398
  else:
399
+ st.caption("_(No annotations made for this item yet)_")
400
+ st.markdown("---") # Separator between history items
401
+
402
+
403
+ # --- Export Annotations ---
404
+ if st.session_state[SESSION_STATE_ANNOTATIONS]:
405
+ st.markdown("---")
406
+ st.subheader("⬇️ Export Annotations")
407
+ try:
408
+ # Prepare data with metadata included per annotation set
409
+ export_data = {}
410
+ # Find corresponding history item to enrich annotation export
411
+ history_map = {item['prompt']: item for item in st.session_state[SESSION_STATE_HISTORY] if item.get('status') != 'failed'}
412
+
413
+ for prompt, ann_objs in st.session_state[SESSION_STATE_ANNOTATIONS].items():
414
+ history_item = history_map.get(prompt)
415
+ export_data[prompt] = {
416
+ "annotations": ann_objs,
417
+ "generation_details": {
418
+ "model": history_item.get('model'),
419
+ "style": history_item.get('style'),
420
+ "strength": history_item.get('strength'),
421
+ } if history_item else None,
422
+ "metadata": history_item.get('metadata') if history_item else None
423
+ }
424
+
425
+ json_data = json.dumps(export_data, indent=2)
426
  st.download_button(
427
+ label="⬇️ Export All Annotations (JSON)",
428
+ data=json_data,
429
+ file_name="medsketch_session_annotations.json",
430
+ mime="application/json",
431
+ help="Download all annotations made during this session, including associated metadata."
 
 
 
 
 
 
 
 
 
 
 
 
432
  )
433
+ except Exception as e:
434
+ st.error(f"Failed to prepare annotations for download: {e}")
435
+ logger.error(f"Error preparing JSON export: {e}")
 
436
 
437
+ elif generate_button: # If generate was clicked but history is empty (e.g., all failed)
438
+ st.info("No successful generations or annotations in the current session yet.")
439
+
440
+ # Add a footer (optional)
441
+ st.markdown("---")
442
+ st.caption("MedSketch AI - Powered by Streamlit and OpenAI")