Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,388 +1,111 @@
|
|
1 |
-
#
|
|
|
2 |
import streamlit as st
|
3 |
-
|
4 |
-
import chromadb
|
5 |
-
from chromadb.utils import embedding_functions
|
6 |
from PIL import Image
|
7 |
-
import
|
8 |
-
import
|
9 |
-
import
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
50 |
-
]
|
51 |
-
GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
|
52 |
-
Describe the key visual features relevant to a medical context.
|
53 |
-
Identify potential:
|
54 |
-
- Diseases or conditions indicated
|
55 |
-
- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
|
56 |
-
- Visible cell types
|
57 |
-
- Relevant biomarkers (if inferable from staining or morphology)
|
58 |
-
- Anatomical context (if discernible)
|
59 |
-
|
60 |
-
Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
|
61 |
-
Structure the output clearly, perhaps using bullet points for findings.
|
62 |
-
"""
|
63 |
-
|
64 |
-
# Chroma DB Configuration
|
65 |
-
CHROMA_PATH = "chroma_data_biobert" # Changed path to reflect model change
|
66 |
-
COLLECTION_NAME = "medical_docs_biobert" # Changed collection name
|
67 |
-
|
68 |
-
# --- Embedding Model Selection ---
|
69 |
-
# Using BioBERT v1.1 - Good domain knowledge, but potentially suboptimal for *semantic similarity search*.
|
70 |
-
# Default pooling (likely CLS token) will be used by sentence-transformers.
|
71 |
-
# Consider models fine-tuned for sentence similarity if retrieval quality is low:
|
72 |
-
# e.g., 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'
|
73 |
-
EMBEDDING_MODEL_NAME = "dmis-lab/biobert-v1.1"
|
74 |
-
CHROMA_DISTANCE_METRIC = "cosine" # Cosine is generally good for sentence embeddings
|
75 |
-
|
76 |
-
# --- Caching Resource Initialization ---
|
77 |
-
|
78 |
-
@st.cache_resource
|
79 |
-
def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
|
80 |
-
"""Initializes and returns the Gemini Generative Model."""
|
81 |
-
try:
|
82 |
-
genai.configure(api_key=GOOGLE_API_KEY)
|
83 |
-
model = genai.GenerativeModel(
|
84 |
-
model_name=VISION_MODEL_NAME,
|
85 |
-
generation_config=GENERATION_CONFIG,
|
86 |
-
safety_settings=SAFETY_SETTINGS
|
87 |
-
)
|
88 |
-
logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
|
89 |
-
return model
|
90 |
-
except Exception as e:
|
91 |
-
err_msg = f"β Error initializing Gemini Model ({VISION_MODEL_NAME}): {e}"
|
92 |
-
st.error(err_msg) # Safe to call st.error here now
|
93 |
-
logger.error(err_msg, exc_info=True)
|
94 |
-
return None
|
95 |
-
|
96 |
-
@st.cache_resource
|
97 |
-
def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
|
98 |
-
"""Initializes and returns the Hugging Face Embedding Function."""
|
99 |
-
st.info(f"Initializing Embedding Model: {EMBEDDING_MODEL_NAME} (this may take a moment)...")
|
100 |
-
try:
|
101 |
-
# Pass HF_TOKEN if it exists (required for private/gated models)
|
102 |
-
embed_func = embedding_functions.HuggingFaceEmbeddingFunction(
|
103 |
-
api_key=HF_TOKEN, # Pass token here if needed by model
|
104 |
-
model_name=EMBEDDING_MODEL_NAME
|
105 |
-
)
|
106 |
-
logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
|
107 |
-
st.success(f"Embedding Model {EMBEDDING_MODEL_NAME} initialized.")
|
108 |
-
return embed_func
|
109 |
-
except Exception as e:
|
110 |
-
err_msg = f"β Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
|
111 |
-
st.error(err_msg) # Safe here
|
112 |
-
logger.error(err_msg, exc_info=True)
|
113 |
-
st.info("βΉοΈ Make sure the embedding model name is correct and you have network access. "
|
114 |
-
"If using a private model, ensure HF_TOKEN is set in secrets. Check Space logs for details.")
|
115 |
-
return None
|
116 |
-
|
117 |
-
@st.cache_resource
|
118 |
-
def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
|
119 |
-
"""Initializes the Chroma DB client and returns the collection."""
|
120 |
-
if not _embedding_func:
|
121 |
-
st.error("β Cannot initialize Chroma DB without a valid embedding function.") # Safe here
|
122 |
-
return None
|
123 |
-
st.info(f"Initializing Chroma DB collection '{COLLECTION_NAME}'...")
|
124 |
-
try:
|
125 |
-
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
|
126 |
-
collection = chroma_client.get_or_create_collection(
|
127 |
-
name=COLLECTION_NAME,
|
128 |
-
embedding_function=_embedding_func, # Pass the initialized function
|
129 |
-
metadata={"hnsw:space": CHROMA_DISTANCE_METRIC}
|
130 |
-
)
|
131 |
-
logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
|
132 |
-
st.success(f"Chroma DB collection '{COLLECTION_NAME}' ready.")
|
133 |
-
return collection
|
134 |
-
except Exception as e:
|
135 |
-
err_msg = f"β Error initializing Chroma DB at '{CHROMA_PATH}': {e}"
|
136 |
-
st.error(err_msg) # Safe here
|
137 |
-
logger.error(err_msg, exc_info=True)
|
138 |
-
st.info(f"βΉοΈ Ensure the path '{CHROMA_PATH}' is writable. Check Space logs.")
|
139 |
-
return None
|
140 |
-
|
141 |
-
# --- Core Logic Functions (with Caching for Data Operations) ---
|
142 |
-
|
143 |
-
@st.cache_data(show_spinner=False) # Show spinner manually in UI
|
144 |
-
def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
|
145 |
-
"""
|
146 |
-
Analyzes image bytes with Gemini, returns (analysis_text, is_error).
|
147 |
-
Uses Streamlit's caching based on image_bytes.
|
148 |
-
"""
|
149 |
-
if not _gemini_model:
|
150 |
-
return "Error: Gemini model not initialized.", True
|
151 |
-
|
152 |
-
try:
|
153 |
-
img = Image.open(io.BytesIO(image_bytes))
|
154 |
-
response = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, img])
|
155 |
-
|
156 |
-
if not response.parts:
|
157 |
-
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
158 |
-
reason = response.prompt_feedback.block_reason
|
159 |
-
msg = f"Analysis blocked by safety settings: {reason}"
|
160 |
-
logger.warning(msg)
|
161 |
-
return msg, True # Indicate block/error state
|
162 |
-
else:
|
163 |
-
msg = "Error: Gemini analysis returned no content (empty or invalid response)."
|
164 |
-
logger.error(msg)
|
165 |
-
return msg, True
|
166 |
-
logger.info("Gemini analysis successful.")
|
167 |
-
return response.text, False # Indicate success
|
168 |
-
|
169 |
-
except genai.types.BlockedPromptException as e:
|
170 |
-
msg = f"Analysis blocked (prompt issue): {e}"
|
171 |
-
logger.warning(msg)
|
172 |
-
return msg, True
|
173 |
-
except Exception as e:
|
174 |
-
msg = f"Error during Gemini analysis: {e}"
|
175 |
-
logger.error(msg, exc_info=True)
|
176 |
-
return msg, True
|
177 |
-
|
178 |
-
@st.cache_data(show_spinner=False)
|
179 |
-
def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
|
180 |
-
"""Queries Chroma DB, returns results dict or None on error."""
|
181 |
-
if not _collection:
|
182 |
-
logger.error("Query attempt failed: Chroma collection is not available.")
|
183 |
-
return None
|
184 |
-
if not query_text:
|
185 |
-
logger.warning("Attempted to query Chroma with empty text.")
|
186 |
-
return None
|
187 |
-
try:
|
188 |
-
refined_query = query_text # Using direct analysis text for now
|
189 |
-
|
190 |
-
results = _collection.query(
|
191 |
-
query_texts=[refined_query],
|
192 |
-
n_results=n_results,
|
193 |
-
include=['documents', 'metadatas', 'distances']
|
194 |
-
)
|
195 |
-
logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
|
196 |
-
return results
|
197 |
-
except Exception as e:
|
198 |
-
# Show error in UI as well
|
199 |
-
st.error(f"β Error querying Chroma DB: {e}", icon="π¨")
|
200 |
-
logger.error(f"Error querying Chroma DB: {e}", exc_info=True)
|
201 |
-
return None
|
202 |
-
|
203 |
-
def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
|
204 |
-
"""Adds example medical text snippets to Chroma using the provided embedding function."""
|
205 |
-
if not collection or not embedding_func:
|
206 |
-
st.error("β Cannot add dummy data: Chroma Collection or Embedding Function not available.")
|
207 |
-
return
|
208 |
-
|
209 |
-
# Check if dummy data needs adding first to avoid unnecessary processing
|
210 |
-
docs_to_check = [
|
211 |
-
"Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive."
|
212 |
-
] # Only check one doc for speed
|
213 |
-
try:
|
214 |
-
existing_check = collection.get(where={"document": docs_to_check[0]}, limit=1, include=[])
|
215 |
-
if existing_check and existing_check.get('ids'):
|
216 |
-
st.info("Dummy data seems to already exist. Skipping add.")
|
217 |
-
logger.info("Skipping dummy data addition as it likely exists.")
|
218 |
-
return
|
219 |
-
except Exception as e:
|
220 |
-
logger.warning(f"Could not efficiently check for existing dummy data: {e}. Proceeding with add attempt.")
|
221 |
-
|
222 |
-
|
223 |
-
status = st.status(f"Adding dummy data (using {EMBEDDING_MODEL_NAME})...", expanded=True)
|
224 |
-
try:
|
225 |
-
# --- Dummy Data Definition ---
|
226 |
-
docs = [
|
227 |
-
"Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
|
228 |
-
"Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
|
229 |
-
"This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
|
230 |
-
"Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
|
231 |
-
"Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted."
|
232 |
-
]
|
233 |
-
metadatas = [
|
234 |
-
{"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
|
235 |
-
{"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
|
236 |
-
{"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer", "IMAGE_ID": "diagram_egfr_pathway.svg"},
|
237 |
-
{"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
|
238 |
-
{"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"}
|
239 |
-
]
|
240 |
-
# Ensure IDs are unique even if run close together
|
241 |
-
base_id = f"doc_biobert_{int(time.time() * 1000)}"
|
242 |
-
ids = [f"{base_id}_{i}" for i in range(len(docs))]
|
243 |
-
|
244 |
-
status.update(label=f"Generating embeddings & adding {len(docs)} documents (this uses BioBERT and may take time)...")
|
245 |
-
|
246 |
-
# Embeddings are generated implicitly by ChromaDB during .add()
|
247 |
-
collection.add(
|
248 |
-
documents=docs,
|
249 |
-
metadatas=metadatas,
|
250 |
-
ids=ids
|
251 |
-
)
|
252 |
-
status.update(label=f"β
Added {len(docs)} dummy documents.", state="complete", expanded=False)
|
253 |
-
logger.info(f"Added {len(docs)} dummy documents to collection '{COLLECTION_NAME}'.")
|
254 |
-
|
255 |
-
except Exception as e:
|
256 |
-
err_msg = f"Error adding dummy data to Chroma: {e}"
|
257 |
-
status.update(label=f"β Error: {err_msg}", state="error", expanded=True)
|
258 |
-
logger.error(err_msg, exc_info=True)
|
259 |
-
|
260 |
-
# --- Initialize Resources ---
|
261 |
-
# These calls use @st.cache_resource, run only once unless cleared/changed.
|
262 |
-
# Order matters if one depends on another (embedding func needed for chroma).
|
263 |
-
gemini_model = initialize_gemini_model()
|
264 |
-
embedding_func = initialize_embedding_function()
|
265 |
-
collection = initialize_chroma_collection(embedding_func) # Pass embedding func
|
266 |
-
|
267 |
-
# --- Streamlit UI ---
|
268 |
-
# set_page_config() is already called at the top
|
269 |
-
|
270 |
-
st.title("βοΈ Medical Image Analysis & RAG (BioBERT Embeddings)")
|
271 |
-
|
272 |
-
# --- DISCLAIMER ---
|
273 |
-
st.warning("""
|
274 |
-
**β οΈ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
|
275 |
-
It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis, treatment, or decision-making.
|
276 |
-
AI analysis can be imperfect. Always consult with qualified healthcare professionals for any medical concerns.
|
277 |
-
Do **NOT** upload identifiable patient data (PHI). Analysis quality depends heavily on the chosen embedding model.
|
278 |
-
""", icon="β£οΈ")
|
279 |
-
|
280 |
-
st.markdown(f"""
|
281 |
-
Upload a medical image. Gemini Vision will analyze it. Related information
|
282 |
-
will be retrieved from a Chroma DB knowledge base using **{EMBEDDING_MODEL_NAME}** embeddings.
|
283 |
-
""")
|
284 |
-
|
285 |
-
# Sidebar
|
286 |
-
with st.sidebar:
|
287 |
-
st.header("βοΈ Controls")
|
288 |
-
uploaded_file = st.file_uploader(
|
289 |
-
"Choose an image...",
|
290 |
-
type=["jpg", "jpeg", "png", "tiff", "webp"],
|
291 |
-
help="Upload a medical image file (e.g., pathology, diagram)."
|
292 |
)
|
|
|
293 |
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
if collection and embedding_func:
|
298 |
-
add_dummy_data_to_chroma(collection, embedding_func)
|
299 |
-
else:
|
300 |
-
st.error("β Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
|
301 |
-
|
302 |
-
st.divider()
|
303 |
-
|
304 |
-
st.header("βΉοΈ System Info")
|
305 |
-
st.caption(f"**Gemini Model:** `{VISION_MODEL_NAME}`")
|
306 |
-
st.caption(f"**Embedding Model:** `{EMBEDDING_MODEL_NAME}`")
|
307 |
-
st.caption(f"**Chroma Collection:** `{COLLECTION_NAME}`")
|
308 |
-
st.caption(f"**Chroma Path:** `{CHROMA_PATH}`")
|
309 |
-
st.caption(f"**Distance Metric:** `{CHROMA_DISTANCE_METRIC}`")
|
310 |
-
st.caption(f"**Google API Key:** {'Set' if GOOGLE_API_KEY else 'Not Set'}")
|
311 |
-
st.caption(f"**HF Token:** {'Provided' if HF_TOKEN else 'Not Provided'}")
|
312 |
-
|
313 |
-
# Main Display Area
|
314 |
-
col1, col2 = st.columns(2)
|
315 |
-
|
316 |
-
with col1:
|
317 |
-
st.subheader("πΌοΈ Uploaded Image")
|
318 |
-
if uploaded_file is not None:
|
319 |
-
image_bytes = uploaded_file.getvalue()
|
320 |
-
st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
|
321 |
else:
|
322 |
-
st.
|
323 |
-
|
324 |
-
with
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
st.
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
st.info("Analysis results will appear here once an image is uploaded.")
|
380 |
-
else:
|
381 |
-
# Initialization error occurred earlier, resources might be None
|
382 |
-
st.error("β Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar info & Space logs).")
|
383 |
-
|
384 |
-
|
385 |
-
st.markdown("---")
|
386 |
-
st.markdown("<div style='text-align: center; font-size: small;'>Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit</div>", unsafe_allow_html=True)
|
387 |
-
|
388 |
-
|
|
|
1 |
+
# app.py
|
2 |
+
import os
|
3 |
import streamlit as st
|
4 |
+
from streamlit_drawable_canvas import st_canvas
|
|
|
|
|
5 |
from PIL import Image
|
6 |
+
import openai
|
7 |
+
from io import BytesIO
|
8 |
+
import json
|
9 |
+
|
10 |
+
# βββ 1. Configuration & Secrets βββββββββββββββββββββββββββββββββββββββββββββ
|
11 |
+
openai.api_key = st.secrets["OPENAI_API_KEY"] # or os.getenv("OPENAI_API_KEY")
|
12 |
+
st.set_page_config(
|
13 |
+
page_title="MedSketchβ―AI",
|
14 |
+
layout="wide",
|
15 |
+
initial_sidebar_state="expanded",
|
16 |
+
)
|
17 |
+
|
18 |
+
# βββ 2. Sidebar: Settings & Metadata ββββββββββββββββββββββββββββββββββββββββ
|
19 |
+
st.sidebar.header("βοΈ Settings")
|
20 |
+
model_choice = st.sidebar.selectbox(
|
21 |
+
"Model",
|
22 |
+
["GPT-4o (API)", "Stable Diffusion LoRA"],
|
23 |
+
index=0
|
24 |
+
)
|
25 |
+
style_preset = st.sidebar.radio(
|
26 |
+
"Preset Style",
|
27 |
+
["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
|
28 |
+
)
|
29 |
+
strength = st.sidebar.slider("Stylization Strength", 0.1, 1.0, 0.7)
|
30 |
+
|
31 |
+
st.sidebar.markdown("---")
|
32 |
+
st.sidebar.header("π Metadata")
|
33 |
+
patient_id = st.sidebar.text_input("Patient / Case ID")
|
34 |
+
roi = st.sidebar.text_input("Region of Interest")
|
35 |
+
umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")
|
36 |
+
|
37 |
+
# βββ 3. Main: Prompt Input & Batch Generation βββββββββββββββββββββββββββββββ
|
38 |
+
st.title("πΌοΈ MedSketchβ―AI β Advanced Clinical Diagram Generator")
|
39 |
+
|
40 |
+
with st.expander("π Enter Prompts (one per line for batch)"):
|
41 |
+
raw = st.text_area(
|
42 |
+
"Describe what you need:",
|
43 |
+
placeholder=(
|
44 |
+
"e.g. βGenerate a labeled crossβsection of the human heart with chamber names, valves, and flow arrowsβ¦β\n"
|
45 |
+
"e.g. βProduce a stylized H&E stain of liver tissue highlighting portal triadsβ¦β"
|
46 |
+
),
|
47 |
+
height=120
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
)
|
49 |
+
prompts = [p.strip() for p in raw.splitlines() if p.strip()]
|
50 |
|
51 |
+
if st.button("π Generate"):
|
52 |
+
if not prompts:
|
53 |
+
st.error("Please enter at least one prompt.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
else:
|
55 |
+
cols = st.columns(min(3, len(prompts)))
|
56 |
+
for i, prompt in enumerate(prompts):
|
57 |
+
with st.spinner(f"Rendering image {i+1}/{len(prompts)}β¦"):
|
58 |
+
if model_choice == "GPT-4o (API)":
|
59 |
+
resp = openai.Image.create(
|
60 |
+
model="gpt-4o",
|
61 |
+
prompt=f"[{style_preset} | strength={strength}] {prompt}",
|
62 |
+
size="1024x1024"
|
63 |
+
)
|
64 |
+
img_data = requests.get(resp["data"][0]["url"]).content
|
65 |
+
else:
|
66 |
+
# stub for Stable Diffusion LoRA
|
67 |
+
img_data = generate_sd_image(prompt, style=style_preset, strength=strength)
|
68 |
+
img = Image.open(BytesIO(img_data))
|
69 |
+
|
70 |
+
# Display + Download
|
71 |
+
with cols[i]:
|
72 |
+
st.image(img, use_column_width=True, caption=prompt)
|
73 |
+
buf = BytesIO()
|
74 |
+
img.save(buf, format="PNG")
|
75 |
+
st.download_button(
|
76 |
+
label="β¬οΈ Download PNG",
|
77 |
+
data=buf.getvalue(),
|
78 |
+
file_name=f"medsketch_{i+1}.png",
|
79 |
+
mime="image/png"
|
80 |
+
)
|
81 |
+
|
82 |
+
# βββ Annotation Canvas βββββββββββββββββββββββββββ
|
83 |
+
st.markdown("**βοΈ Annotate:**")
|
84 |
+
canvas_res = st_canvas(
|
85 |
+
fill_color="rgba(255, 0, 0, 0.3)", # annotation color
|
86 |
+
stroke_width=2,
|
87 |
+
background_image=img,
|
88 |
+
update_streamlit=True,
|
89 |
+
height=512,
|
90 |
+
width=512,
|
91 |
+
drawing_mode="freedraw",
|
92 |
+
key=f"canvas_{i}"
|
93 |
+
)
|
94 |
+
# Save annotations
|
95 |
+
if canvas_res.json_data:
|
96 |
+
ann = canvas_res.json_data["objects"]
|
97 |
+
st.session_state.setdefault("annotations", {})[prompt] = ann
|
98 |
+
|
99 |
+
# βββ 4. History & Exports βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
100 |
+
if "annotations" in st.session_state:
|
101 |
+
st.markdown("---")
|
102 |
+
st.subheader("π Session History & Annotations")
|
103 |
+
for prm, objs in st.session_state["annotations"].items():
|
104 |
+
st.markdown(f"**Prompt:** {prm}")
|
105 |
+
st.json(objs)
|
106 |
+
st.download_button(
|
107 |
+
"β¬οΈ Export All Annotations (JSON)",
|
108 |
+
data=json.dumps(st.session_state["annotations"], indent=2),
|
109 |
+
file_name="medsketch_annotations.json",
|
110 |
+
mime="application/json"
|
111 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|