Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,111 +1,442 @@
|
|
1 |
# app.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import os
|
|
|
|
|
|
|
|
|
|
|
3 |
import streamlit as st
|
4 |
from streamlit_drawable_canvas import st_canvas
|
5 |
from PIL import Image
|
6 |
import openai
|
7 |
-
from
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
# βββ 1. Configuration & Secrets βββββββββββββββββββββββββββββββββββββββββββββ
|
11 |
-
openai.api_key = st.secrets["OPENAI_API_KEY"] # or os.getenv("OPENAI_API_KEY")
|
12 |
st.set_page_config(
|
13 |
-
page_title=
|
14 |
layout="wide",
|
15 |
initial_sidebar_state="expanded",
|
|
|
|
|
|
|
|
|
|
|
16 |
)
|
17 |
|
18 |
-
#
|
19 |
-
st.
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
roi = st.sidebar.text_input("Region of Interest")
|
35 |
-
umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")
|
36 |
|
37 |
-
# βββ
|
38 |
-
st.title("πΌοΈ MedSketchβ―AI β Advanced Clinical Diagram Generator")
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
placeholder=(
|
44 |
-
"
|
45 |
-
"
|
|
|
46 |
),
|
47 |
-
height=
|
|
|
48 |
)
|
49 |
-
prompts = [p.strip() for p in
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
if not prompts:
|
53 |
-
st.
|
54 |
else:
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
for i, prompt in enumerate(prompts):
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
else:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
st.download_button(
|
76 |
-
label="β¬οΈ
|
77 |
-
data=
|
78 |
-
file_name=
|
79 |
-
mime="
|
80 |
-
|
81 |
-
|
82 |
-
# βββ Annotation Canvas βββββββββββββββββββββββββββ
|
83 |
-
st.markdown("**βοΈ Annotate:**")
|
84 |
-
canvas_res = st_canvas(
|
85 |
-
fill_color="rgba(255, 0, 0, 0.3)", # annotation color
|
86 |
-
stroke_width=2,
|
87 |
-
background_image=img,
|
88 |
-
update_streamlit=True,
|
89 |
-
height=512,
|
90 |
-
width=512,
|
91 |
-
drawing_mode="freedraw",
|
92 |
-
key=f"canvas_{i}"
|
93 |
)
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
st.session_state.setdefault("annotations", {})[prompt] = ann
|
98 |
|
99 |
-
#
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
st.json(objs)
|
106 |
-
st.download_button(
|
107 |
-
"β¬οΈ Export All Annotations (JSON)",
|
108 |
-
data=json.dumps(st.session_state["annotations"], indent=2),
|
109 |
-
file_name="medsketch_annotations.json",
|
110 |
-
mime="application/json"
|
111 |
-
)
|
|
|
1 |
# app.py
|
2 |
+
"""
|
3 |
+
MedSketch AI: Advanced Clinical Diagram Generator
|
4 |
+
|
5 |
+
A Streamlit application leveraging AI models (GPT-4o, potentially Stable Diffusion)
|
6 |
+
to generate medical diagrams based on user prompts, with options for styling,
|
7 |
+
metadata association, and annotations.
|
8 |
+
"""
|
9 |
+
|
10 |
import os
|
11 |
+
import json
|
12 |
+
import logging
|
13 |
+
from io import BytesIO
|
14 |
+
from typing import List, Dict, Any, Optional, Tuple
|
15 |
+
|
16 |
import streamlit as st
|
17 |
from streamlit_drawable_canvas import st_canvas
|
18 |
from PIL import Image
|
19 |
import openai
|
20 |
+
from openai import OpenAI, OpenAIError # Use modern OpenAI client and error types
|
21 |
+
|
22 |
+
# ─── Constants ──────────────────────────────────────────────────────────────

# Display name used for the page title, the main header, and menu "About" text.
APP_TITLE = "MedSketch AI β Advanced Clinical Diagram Generator"
DEFAULT_MODEL = "GPT-4o (Vision)"  # Updated model name; maps to the OpenAI backend
STABLE_DIFFUSION_MODEL = "Stable Diffusion LoRA"  # Placeholder name; backend not implemented yet
MODEL_OPTIONS = [DEFAULT_MODEL, STABLE_DIFFUSION_MODEL]
# Preset styles offered in the sidebar; "Custom" unlocks a free-text style field.
STYLE_PRESETS = ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
DEFAULT_STYLE = "Anatomical Diagram"
DEFAULT_STRENGTH = 0.7  # Default stylization strength for the sidebar slider
IMAGE_SIZE = "1024x1024"  # Size string passed to the OpenAI Images API
CANVAS_SIZE = 512  # Max edge (px) of the annotation canvas background
ANNOTATION_COLOR = "rgba(255, 0, 0, 0.3)"  # Red with transparency
ANNOTATION_STROKE_WIDTH = 2
# st.session_state keys (namespaced to avoid collisions with widget keys).
SESSION_STATE_ANNOTATIONS = "medsketch_annotations"
SESSION_STATE_HISTORY = "medsketch_history"  # Store generated images too

# ─── Setup & Configuration ──────────────────────────────────────────────────

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
43 |
# Page-level configuration; must run before any other Streamlit call.
st.set_page_config(
    page_title=APP_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': f"{APP_TITLE} - Generates medical diagrams using AI.",
        'Get Help': None,  # Add a link if you have one
        'Report a bug': None  # Add a link if you have one
    }
)

# Initialize OpenAI Client (Best Practice)
# Use st.secrets for deployment, fallback to env var for local dev
api_key = st.secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY"))
if not api_key:
    st.error("π¨ OpenAI API Key not found! Please set it in Streamlit secrets or environment variables.", icon="π¨")
    st.stop()  # Halt execution if no key

try:
    # Module-level client shared by generate_openai_image().
    client = OpenAI(api_key=api_key)
    logger.info("OpenAI client initialized successfully.")
except Exception as e:
    # A bad key or network issue makes the whole app unusable; stop early.
    st.error(f"π¨ Failed to initialize OpenAI client: {e}", icon="π¨")
    logger.exception("OpenAI client initialization failed.")
    st.stop()
# ─── Helper Functions ───────────────────────────────────────────────────────

def generate_openai_image(prompt: str, style: str, strength: float) -> Image.Image:
    """
    Generates an image using the OpenAI Images API.

    The style preset and strength are folded into the text prompt; the API
    returns an image URL which is then downloaded and decoded into a PIL Image.

    Args:
        prompt: The user's text prompt.
        style: The selected style preset.
        strength: The stylization strength (conceptually used in prompt).

    Returns:
        A PIL Image object.

    Raises:
        OpenAIError: If the API call fails.
        IOError: If the image data cannot be downloaded or processed.
    """
    # FIX: import at the top of the function instead of mid-`try`. Previously
    # the import happened only after the API call succeeded, so if an
    # unexpected exception was raised earlier, evaluating the
    # `except requests.exceptions.RequestException` clause itself raised a
    # NameError ("requests" not yet bound) and masked the real error.
    import requests

    logger.info(f"Requesting OpenAI image generation for prompt: '{prompt}' with style '{style}'")
    full_prompt = f"Style: [{style}], Strength: [{strength:.2f}] - Generate the following medical illustration: {prompt}"
    try:
        response = client.images.generate(
            model="dall-e-3",  # Or "gpt-4o" if/when available via this endpoint. DALL-E 3 is current standard.
            prompt=full_prompt,
            size=IMAGE_SIZE,
            quality="standard",  # or "hd"
            n=1,
            response_format="url"  # Or "b64_json" to avoid a second request
        )
        image_url = response.data[0].url
        logger.info(f"Image generated successfully, URL: {image_url}")

        # Fetch the image data from the URL.
        # Note: Using response_format="b64_json" would avoid this extra step.
        image_response = requests.get(image_url, timeout=30)  # Add timeout
        image_response.raise_for_status()  # Check for HTTP errors

        img_data = BytesIO(image_response.content)
        img = Image.open(img_data)
        return img

    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        st.error(f"β OpenAI API Error: {e}", icon="β")
        raise
    except requests.exceptions.RequestException as e:
        # Download of the generated image failed; surface as IOError to callers.
        logger.error(f"Failed to download image from URL {image_url}: {e}")
        st.error(f"β Network Error: Failed to download image. {e}", icon="β")
        raise IOError(f"Failed to download image: {e}") from e
    except Exception as e:
        logger.exception(f"An unexpected error occurred during OpenAI image generation: {e}")
        st.error(f"β An unexpected error occurred: {e}", icon="β")
        raise
def generate_sd_image(prompt: str, style: str, strength: float) -> Image.Image:
    """
    Stand-in for a Stable Diffusion LoRA backend (not yet wired up).

    Emits warnings, then builds a flat grey dummy image stamped with the
    requested style and a truncated copy of the prompt so the rest of the
    UI pipeline can still be exercised end to end.

    Args:
        prompt: The user's text prompt.
        style: The selected style preset.
        strength: The stylization strength (unused by the placeholder).

    Returns:
        A placeholder PIL Image (CANVAS_SIZE x CANVAS_SIZE).
    """
    from PIL import ImageDraw
    import time

    logger.warning("Stable Diffusion LoRA model generation is not implemented. Returning placeholder.")
    st.warning("π§ Stable Diffusion LoRA generation is not yet implemented. Using placeholder.", icon="π§")

    # --- Placeholder Implementation: replace with a real SD model call ---
    placeholder = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), color=(210, 210, 210))
    drawer = ImageDraw.Draw(placeholder)
    drawer.text(
        (10, 10),
        f"Stable Diffusion Placeholder\nStyle: {style}\nPrompt: {prompt[:50]}...",
        fill=(0, 0, 0),
    )
    # --- End Placeholder ---

    time.sleep(1)  # Simulate some processing time
    return placeholder
    # raise NotImplementedError("Stable Diffusion LoRA generation is not yet available.")
def display_result(image: Image.Image, prompt: str, index: int, total: int) -> Optional[List[Dict[str, Any]]]:
    """
    Displays a generated image, download button, and annotation canvas.

    Args:
        image: The PIL Image to display.
        prompt: The prompt used to generate the image.
        index: The 0-based index of the current image in a batch.
        total: The total number of images in the batch.

    Returns:
        Annotation data (list of canvas object dicts) if annotations were
        drawn on the canvas, otherwise None.
    """
    # NOTE(review): `use_column_width` is deprecated in newer Streamlit
    # releases in favour of `use_container_width` -- confirm pinned version.
    st.image(image, caption=f"Result {index + 1}/{total}: {prompt}", use_column_width='always')

    # Serialize the image to an in-memory PNG for the download button.
    buf = BytesIO()
    image.save(buf, format="PNG")
    buf.seek(0)

    st.download_button(
        label="β¬οΈ Download PNG",
        data=buf,
        file_name=f"medsketch_{index+1}_{prompt[:20].replace(' ', '_')}.png",
        mime="image/png",
        key=f"download_{index}"  # Unique widget key per image in the batch
    )

    # Annotation Canvas
    st.markdown("**βοΈ Annotate:**")
    # Shrink a copy to fit the canvas while preserving aspect ratio;
    # the original image object is left untouched.
    canvas_image = image.copy()
    canvas_image.thumbnail((CANVAS_SIZE, CANVAS_SIZE))

    canvas_result = st_canvas(
        fill_color=ANNOTATION_COLOR,
        stroke_width=ANNOTATION_STROKE_WIDTH,
        background_image=canvas_image,
        update_streamlit=True,  # Update in real-time
        height=canvas_image.height,
        width=canvas_image.width,
        drawing_mode="freedraw",  # Or choose other modes like "line", "rect", etc.
        key=f"canvas_{index}"  # Unique widget key per image in the batch
    )

    # Return the raw drawn objects from the canvas JSON payload, if any.
    if canvas_result.json_data and canvas_result.json_data.get("objects"):
        return canvas_result.json_data["objects"]
    return None
|
|
|
|
# ─── Initialize Session State ───────────────────────────────────────────────

# Annotations are keyed by the prompt text that produced the image.
if SESSION_STATE_ANNOTATIONS not in st.session_state:
    st.session_state[SESSION_STATE_ANNOTATIONS] = {}  # Dict[prompt, List[annotation_objects]]
# History records every generation attempt made during this session.
if SESSION_STATE_HISTORY not in st.session_state:
    st.session_state[SESSION_STATE_HISTORY] = []  # List[Dict[str, Any]] storing generation results
217 |
+
# ─── Sidebar: Settings & Metadata ───────────────────────────────────────────

with st.sidebar:
    st.header("βοΈ Generation Settings")
    # Which backend to use: OpenAI (default) or the SD placeholder.
    model_choice = st.selectbox(
        "Select Model",
        options=MODEL_OPTIONS,
        index=MODEL_OPTIONS.index(DEFAULT_MODEL),
        help="Choose the AI model for image generation."
    )

    style_preset = st.radio(
        "Select Preset Style",
        options=STYLE_PRESETS,
        index=STYLE_PRESETS.index(DEFAULT_STYLE),
        horizontal=True,  # More compact layout
        help="Apply a predefined visual style to the generation."
    )
    # Allow custom style input only if "Custom" is selected
    custom_style_input = ""
    if style_preset == "Custom":
        custom_style_input = st.text_input("Enter Custom Style Description:", key="custom_style")
    # The style string actually passed to the generation helpers.
    final_style = custom_style_input if style_preset == "Custom" else style_preset

    strength = st.slider(
        "Stylization Strength",
        min_value=0.1,
        max_value=1.0,
        value=DEFAULT_STRENGTH,
        step=0.05,
        help="Controls how strongly the chosen style influences the result (conceptual)."
    )

    st.markdown("---")
    st.header("π Optional Metadata")
    # Free-text clinical metadata stored alongside each generation record.
    patient_id = st.text_input("Patient / Case ID", key="patient_id", help="Associate with a specific patient or case.")
    roi = st.text_input("Region of Interest (ROI)", key="roi", help="Specify the anatomical region shown.")
    umls_code = st.text_input("UMLS / SNOMED CT Code", key="umls_code", help="Link to relevant medical ontology codes.")

    # Add a clear history button
    st.markdown("---")
    if st.button("β οΈ Clear History & Annotations", help="Removes all generated images and annotations from this session."):
        st.session_state[SESSION_STATE_ANNOTATIONS] = {}
        st.session_state[SESSION_STATE_HISTORY] = []
        st.rerun()  # Refresh the page to reflect cleared state
264 |
+
# ─── Main Application Area ──────────────────────────────────────────────────

st.title(APP_TITLE)
st.markdown("Generate medical illustrations from text descriptions using AI. Annotate and export your results.")

# --- Prompt Input Area ---
prompt_input_area = st.container()
with prompt_input_area:
    st.subheader("π Enter Prompt(s)")
    st.caption("Enter one prompt per line to generate multiple images in a batch.")
    raw_prompts = st.text_area(
        "Describe the medical diagram(s) you need:",
        placeholder=(
            "Example 1: A sagittal view of the human knee joint, labeling the ACL, PCL, meniscus, femur, and tibia.\n"
            "Example 2: High-power field H&E stain of lung adenocarcinoma showing glandular formation.\n"
            "Example 3: Immunohistochemistry (IHC) stain for PD-L1 in tonsil tissue, showing positive staining on immune cells."
        ),
        height=150,  # Slightly larger height
        label_visibility="collapsed"
    )
    # One prompt per non-blank line, whitespace-trimmed.
    prompts: List[str] = [p.strip() for p in raw_prompts.splitlines() if p.strip()]

# --- Generation Trigger ---
generate_button = st.button(
    f"π Generate Diagram{'s' if len(prompts) > 1 else ''}",
    type="primary",
    disabled=not prompts,  # Disable if no prompts
    use_container_width=True
)
|
293 |
+
|
# --- Generation and Display Area ---
results_area = st.container()
if generate_button:
    if not prompts:
        # Defensive: the button is disabled when no prompts exist, but guard anyway.
        st.warning("β οΈ Please enter at least one prompt description.", icon="β οΈ")
    else:
        logger.info(f"Starting generation for {len(prompts)} prompts using model '{model_choice}'.")
        num_prompts = len(prompts)
        max_cols = 3  # Lay results out up to 3 columns wide
        cols = st.columns(min(max_cols, num_prompts))

        # Use a progress bar for batch generation.
        # FIX: dropped the pointless f-prefix on a placeholder-free string.
        progress_bar = st.progress(0, text="Initializing generation...")

        for i, prompt in enumerate(prompts):
            # FIX: index against the number of columns actually created
            # (len(cols)), which may be smaller than max_cols for small batches.
            col_index = i % len(cols)
            with cols[col_index]:
                st.markdown(f"--- \n**Processing: {i+1}/{num_prompts}**")
                spinner_msg = f"Generating image {i+1}/{num_prompts} for prompt: \"{prompt[:50]}...\""
                with st.spinner(spinner_msg):
                    try:
                        # Select generation function based on model choice.
                        if model_choice == DEFAULT_MODEL:
                            generated_image = generate_openai_image(prompt, final_style, strength)
                        elif model_choice == STABLE_DIFFUSION_MODEL:
                            generated_image = generate_sd_image(prompt, final_style, strength)
                        else:
                            st.error(f"Unknown model selected: {model_choice}", icon="β")
                            continue  # Skip to next prompt

                        # Display result and collect any annotations drawn.
                        annotations = display_result(generated_image, prompt, i, num_prompts)

                        # Record this generation (prompt + settings + metadata only;
                        # storing full images in session state is memory-heavy).
                        result_data = {
                            "prompt": prompt,
                            "model": model_choice,
                            "style": final_style,
                            "strength": strength,
                            "metadata": {
                                "patient_id": patient_id,
                                "roi": roi,
                                "umls_code": umls_code,
                            },
                            "image_ref_index": i  # Reference to this generation instance
                        }
                        st.session_state[SESSION_STATE_HISTORY].append(result_data)

                        if annotations:
                            st.session_state[SESSION_STATE_ANNOTATIONS][prompt] = annotations
                            st.success(f"Annotations saved for prompt {i+1}.", icon="β…")

                    except Exception as e:
                        # FIX: a single `except Exception` replaces the redundant
                        # tuple (OpenAIError, IOError, NotImplementedError, Exception)
                        # -- Exception already subsumes all of the others.
                        st.error(f"Failed to generate image for prompt: '{prompt}'. Error: {e}", icon="π₯")
                        # Record failed attempts so the history view can report them.
                        st.session_state[SESSION_STATE_HISTORY].append({
                            "prompt": prompt, "status": "failed", "error": str(e)
                        })

            # Update progress bar once per prompt, success or failure.
            progress_val = (i + 1) / num_prompts
            progress_bar.progress(progress_val, text=f"Generated {i+1}/{num_prompts} images...")

        progress_bar.progress(1.0, text="Batch generation complete!")
        st.toast(f"Finished generating {num_prompts} image(s)!", icon="π")
        # (Streamlit usually clears the progress bar on the next rerun; remove
        # the element explicitly here if a cleaner teardown is desired.)
# ─── History & Exports Section ──────────────────────────────────────────────

history_area = st.container()
with history_area:
    # Session-state history survives reruns, so it is the source of truth here.
    if st.session_state[SESSION_STATE_HISTORY]:
        st.markdown("---")
        st.subheader("π Session History & Annotations")
        st.caption("Review generated images (if stored) and their annotations from this session.")

        # Walk every recorded generation attempt (successes and failures).
        for idx, item in enumerate(st.session_state[SESSION_STATE_HISTORY]):
            if item.get("status") == "failed":
                st.warning(f"**Prompt {idx+1} (Failed):** {item['prompt']} \n *Error: {item['error']}*", icon="β οΈ")
            else:
                prompt_key = item["prompt"]
                st.markdown(f"**Prompt {idx+1}:** `{prompt_key}`")
                st.markdown(f"*Model: {item['model']}, Style: {item['style']}*")
                # Display metadata if present
                meta = item.get('metadata', {})
                if any(meta.values()):
                    meta_str = ", ".join([f"{k}: {v}" for k, v in meta.items() if v])
                    st.markdown(f"*Metadata: {meta_str}*")

                # Check for annotations for this prompt
                annotations = st.session_state[SESSION_STATE_ANNOTATIONS].get(prompt_key)
                if annotations:
                    with st.expander(f"View {len(annotations)} Annotation(s)"):
                        st.json(annotations)
                else:
                    st.caption("_(No annotations made for this item yet)_")
            st.markdown("---")  # Separator between history items

        # --- Export Annotations ---
        if st.session_state[SESSION_STATE_ANNOTATIONS]:
            st.markdown("---")
            st.subheader("β¬οΈ Export Annotations")
            try:
                # Prepare data with metadata included per annotation set.
                export_data = {}
                # Map prompts back to their (successful) history entries so the
                # export can carry generation details alongside the annotations.
                history_map = {item['prompt']: item for item in st.session_state[SESSION_STATE_HISTORY] if item.get('status') != 'failed'}

                for prompt, ann_objs in st.session_state[SESSION_STATE_ANNOTATIONS].items():
                    history_item = history_map.get(prompt)
                    export_data[prompt] = {
                        "annotations": ann_objs,
                        "generation_details": {
                            "model": history_item.get('model'),
                            "style": history_item.get('style'),
                            "strength": history_item.get('strength'),
                        } if history_item else None,
                        "metadata": history_item.get('metadata') if history_item else None
                    }

                json_data = json.dumps(export_data, indent=2)
                st.download_button(
                    label="β¬οΈ Export All Annotations (JSON)",
                    data=json_data,
                    file_name="medsketch_session_annotations.json",
                    mime="application/json",
                    help="Download all annotations made during this session, including associated metadata."
                )
            except Exception as e:
                st.error(f"Failed to prepare annotations for download: {e}")
                logger.error(f"Error preparing JSON export: {e}")

    elif generate_button:  # If generate was clicked but history is empty (e.g., all failed)
        st.info("No successful generations or annotations in the current session yet.")

# Add a footer (optional)
st.markdown("---")
st.caption("MedSketch AI - Powered by Streamlit and OpenAI")
|
|
|
|
|
|
|
|
|
|
|
|
|