Update app.py
app.py
CHANGED
@@ -1,7 +1,7 @@
-# app.py - MedGemma with
+# app.py - Working MedGemma with Correct Implementation
 import gradio as gr
 import torch
-from transformers import AutoProcessor, AutoModelForImageTextToText
+from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
 from PIL import Image
 import os
 import logging
@@ -30,13 +30,14 @@ def authenticate_hf():
 # Model configuration
 MODEL_ID = "google/medgemma-4b-it"
 
 # Global variables
 model = None
 processor = None
+pipeline_model = None
 
 def load_model():
-    """Load model
-    global model, processor
+    """Load MedGemma model using the recommended approach"""
+    global model, processor, pipeline_model
 
     try:
         # First authenticate
@@ -45,33 +46,43 @@ def load_model():
             logger.error("❌ Authentication required for MedGemma")
             return False
 
-        logger.info(f"Loading
-        # (old loading preamble; truncated in the page render)
-        # Load processor
-        logger.info("Loading processor...")
+        logger.info(f"Loading MedGemma: {MODEL_ID}")
+
+        # Method 1: Try using pipeline (recommended by HuggingFace)
+        try:
+            logger.info("Attempting to load using pipeline...")
+            pipeline_model = pipeline(
+                "image-text-to-text",
+                model=MODEL_ID,
+                torch_dtype=torch.float32,
+                device_map="auto" if torch.cuda.is_available() else None,
+                trust_remote_code=True
+            )
+            logger.info("✅ Pipeline model loaded successfully!")
+            return True
+        except Exception as e:
+            logger.warning(f"Pipeline loading failed: {e}")
+
+        # Method 2: Try direct model loading
+        logger.info("Attempting direct model loading...")
+
+        # Load processor
         processor = AutoProcessor.from_pretrained(
             MODEL_ID,
             trust_remote_code=True,
             token=True
         )
-        logger.info("✅ Processor loaded
+        logger.info("✅ Processor loaded")
 
         # Load model
-        logger.info("Loading model...")
         model = AutoModelForImageTextToText.from_pretrained(
             MODEL_ID,
             torch_dtype=torch.float32,
-            device_map=
+            device_map="auto" if torch.cuda.is_available() else None,
             trust_remote_code=True,
-            low_cpu_mem_usage=True,
             token=True
         )
         logger.info("✅ Model loaded successfully!")
         return True
 
     except Exception as e:
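The pipeline path added above matches the Hugging Face model card pattern for MedGemma. A minimal standalone sketch of that path, assuming a valid HF token with MedGemma access and a local `xray.png` (both assumptions, not part of the commit):

```python
# Smoke-test sketch for the "image-text-to-text" pipeline path.
# Assumptions: an HF token with MedGemma access is configured,
# and xray.png exists locally.
from PIL import Image
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="google/medgemma-4b-it")
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": Image.open("xray.png")},
        {"type": "text", "text": "Describe this chest X-ray."},
    ],
}]
out = pipe(text=messages, max_new_tokens=200)
print(out[0]["generated_text"][-1]["content"])  # assistant turn is appended last
```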
@@ -85,19 +96,19 @@ model_loaded = load_model()
 
 def analyze_medical_image(image, clinical_question, patient_history=""):
     """Analyze medical image with clinical context"""
-    global model, processor
+    global model, processor, pipeline_model
 
     # Check if model is loaded
-    if not model_loaded
+    if not model_loaded:
         return """❌ **Model Loading Issue**
 
-The model failed to load properly. This could be due to:
+MedGemma failed to load. This is likely due to:
 
-1. **
-2. **
-3. **
+1. **Transformers version**: Make sure you're using transformers >= 4.52.0
+2. **Authentication**: Ensure HF_TOKEN is properly set
+3. **Model compatibility**: MedGemma requires the latest transformers library
 
-**
+**Status**: Model loading failed. Please try refreshing the page or contact support."""
 
     if image is None:
         return "⚠️ Please upload a medical image first."
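The new error text pins transformers >= 4.52.0. A guard like the following could fail fast at startup instead of surfacing the problem at analysis time; this is a sketch, not part of the commit, and assumes the `packaging` dependency is available:

```python
# Hypothetical startup guard: refuse to run on a transformers build older
# than the 4.52.0 floor cited in the commit's error message.
from packaging.version import Version
import transformers

if Version(transformers.__version__) < Version("4.52.0"):
    raise RuntimeError(
        f"MedGemma needs transformers >= 4.52.0, found {transformers.__version__}"
    )
```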
@@ -106,65 +117,72 @@ The model failed to load properly. This could be due to:
         return "⚠️ Please provide a clinical question."
 
     try:
-        # (old prompt construction and generate call; truncated in the page render)
-            **inputs,
-            max_new_tokens=1000,  # Reduced for stability
-            do_sample=True,
-            temperature=0.3,
-            top_p=0.95,
-            repetition_penalty=1.1,
-            pad_token_id=processor.tokenizer.eos_token_id if hasattr(processor, 'tokenizer') else None
-        )
-        # (old response decoding; truncated in the page render)
+        # Method 1: Use pipeline if available
+        if pipeline_model is not None:
+            logger.info("Using pipeline for analysis...")
+
+            # Prepare message in the format expected by pipeline
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image},
+                        {"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}\n\nAs MedGemma, provide a detailed medical analysis of this image for educational purposes only."}
+                    ]
+                }
+            ]
+
+            # Generate response using pipeline
+            result = pipeline_model(messages, max_new_tokens=1000)
+
+            # Extract response text
+            response = result[0]['generated_text'] if isinstance(result, list) else result['generated_text']
+
+        # Method 2: Use direct model if pipeline failed
+        elif model is not None and processor is not None:
+            logger.info("Using direct model for analysis...")
+
+            # Prepare messages for direct model
+            messages = [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": "You are MedGemma, an expert medical AI assistant. Provide detailed medical analysis for educational purposes only."}]
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}"},
+                        {"type": "image", "image": image}
+                    ]
+                }
+            ]
+
+            # Process inputs
+            inputs = processor.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
+
+            # Generate response
+            with torch.inference_mode():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=1000,
+                    do_sample=True,
+                    temperature=0.3,
+                    top_p=0.9
+                )
+
+            # Decode response
+            response = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+
+        else:
+            return "❌ No model available for analysis. Please try refreshing the page."
+
+        # Clean up response
         response = response.strip()
 
         # Add medical disclaimer
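One caveat on the pipeline branch above: with chat-style `messages` input, recent transformers releases return `generated_text` as the full conversation (a list of role/content dicts) rather than a plain string, in which case the commit's extraction yields a list, not text. A defensive variant (an assumption about newer pipeline behavior, not part of the commit):

```python
def extract_reply(result):
    # Pipeline output is a list of dicts; generated_text may be a plain
    # string or, for chat input, the whole message list.
    text = result[0]["generated_text"] if isinstance(result, list) else result["generated_text"]
    if isinstance(text, list):
        # Chat format: the assistant's reply is the last message.
        text = text[-1]["content"]
    return text.strip()
```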
@@ -187,31 +205,18 @@ The model failed to load properly. This could be due to:
         logger.error(f"❌ Error in analysis: {str(e)}")
         import traceback
         logger.error(f"Full traceback: {traceback.format_exc()}")
-        return f"❌ Analysis failed: {str(e)}\n\nPlease try with a different image or question."
+        return f"❌ Analysis failed: {str(e)}\n\nPlease try again with a different image or question."
 
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(
-        title="MedGemma Medical
+        title="MedGemma Medical Analysis",
         theme=gr.themes.Soft(),
         css="""
-        .gradio-container {
-        }
-        .disclaimer {
-            background-color: #fef2f2;
-            border: 1px solid #fecaca;
-            border-radius: 8px;
-            padding: 16px;
-            margin: 16px 0;
-        }
-        .success {
-            background-color: #f0f9ff;
-            border: 1px solid #bae6fd;
-            border-radius: 8px;
-            padding: 16px;
-            margin: 16px 0;
-        }
+        .gradio-container { max-width: 1200px !important; }
+        .disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
+        .success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
+        .warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
         """
     ) as demo:
@@ -227,21 +232,22 @@ def create_interface():
 
         # Status display
         if model_loaded:
-            gr.Markdown("""
+            method = "Pipeline" if pipeline_model else "Direct Model"
+            gr.Markdown(f"""
             <div class="success">
-            ✅ <strong>
+            ✅ <strong>MEDGEMMA READY</strong><br>
+            Model loaded successfully using {method} method. Ready for medical image analysis.
             </div>
             """)
         else:
             gr.Markdown("""
-            <div class="
-            ⚠️ <strong>
-            MedGemma
+            <div class="warning">
+            ⚠️ <strong>MODEL LOADING FAILED</strong><br>
+            MedGemma failed to load. Please ensure you have the latest transformers library and proper authentication.
             </div>
             """)
 
-        #
+        # Medical disclaimer
         gr.Markdown("""
         <div class="disclaimer">
         ⚠️ <strong>IMPORTANT MEDICAL DISCLAIMER</strong><br>
@@ -251,66 +257,64 @@ def create_interface():
         """)
 
         with gr.Row():
             # Left column
             with gr.Column(scale=1):
-                gr.Markdown("## 📤
+                gr.Markdown("## 📤 Medical Image Upload")
 
                 image_input = gr.Image(
                     label="Medical Image",
                     type="pil",
-                    height=300,
-                    sources=["upload", "clipboard"]
+                    height=300
                 )
 
                 clinical_question = gr.Textbox(
                     label="Clinical Question *",
-                    placeholder="Examples:\n• Describe
+                    placeholder="Examples:\n• Describe findings in this chest X-ray\n• What pathological changes are visible?\n• Provide differential diagnosis\n• Identify abnormalities",
                     lines=4
                 )
 
                 patient_history = gr.Textbox(
                     label="Patient History (Optional)",
-                    placeholder="e.g., 65-year-old male with chronic cough
+                    placeholder="e.g., 65-year-old male with chronic cough",
                     lines=2
                 )
 
                 with gr.Row():
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary")
-                    analyze_btn = gr.Button("🔍 Analyze
+                    analyze_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
 
-                # System status
-                auth_status = "✅ Authenticated" if model_loaded else "🔄 Loading"
-                model_status = "✅ Ready" if model_loaded else "🔄 Loading"
-
+                # System info
                 gr.Markdown(f"""
-                **
-                **
+                **Status:** {'✅ Ready' if model_loaded else '❌ Failed'}
+                **Method:** {'Pipeline' if pipeline_model else 'Direct' if model else 'None'}
                 **Device:** {'CUDA' if torch.cuda.is_available() else 'CPU'}
+                **Transformers:** {getattr(__import__('transformers'), '__version__', 'Unknown')}
                 """)
 
             # Right column
             with gr.Column(scale=1):
-                gr.Markdown("## 📋 Medical Analysis")
+                gr.Markdown("## 📋 Medical Analysis Results")
 
                 output = gr.Textbox(
-                    label="AI Analysis
+                    label="AI Medical Analysis",
                     lines=20,
                     show_copy_button=True,
-                    placeholder="Upload a medical image and ask a clinical question
+                    placeholder="Upload a medical image and ask a clinical question..." if model_loaded else "Model unavailable - please check system status"
                 )
 
-                # Examples (old examples block; contents not recoverable from the render)
+                # Examples
+                if model_loaded:
+                    with gr.Accordion("📚 Example Cases", open=False):
+                        examples = gr.Examples(
+                            examples=[
+                                [
+                                    "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
+                                    "Analyze this chest X-ray systematically. Comment on heart size, lung fields, and any abnormalities.",
+                                    "Adult patient with respiratory symptoms"
+                                ]
+                            ],
+                            inputs=[image_input, clinical_question, patient_history]
+                        )
 
         # Event handlers
         analyze_btn.click(
@@ -320,11 +324,8 @@ def create_interface():
             show_progress=True
         )
 
-        def clear_all():
-            return None, "", "", ""
-
         clear_btn.click(
-            fn=
+            fn=lambda: (None, "", "", ""),
             outputs=[image_input, clinical_question, patient_history, output]
         )
 
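The `analyze_btn.click(...)` arguments are cut off in the rendered diff; a plausible completion, consistent with the component names above (the exact argument list is an assumption):

```python
# Assumed wiring for the truncated click handler: three inputs in, one
# textbox out, mirroring analyze_medical_image's signature.
analyze_btn.click(
    fn=analyze_medical_image,
    inputs=[image_input, clinical_question, patient_history],
    outputs=output,
    show_progress=True,
)
```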
@@ -333,13 +334,12 @@ def create_interface():
         ---
         ### 🔬 About MedGemma
 
-        MedGemma-4B is Google's specialized medical AI model
-        It demonstrates strong performance across radiology, pathology, dermatology, and ophthalmology.
+        MedGemma-4B is Google's specialized medical AI model requiring transformers >= 4.52.0.
 
         ### 🔒 Privacy & Ethics
-        - Real-time processing
-        -
-        - No
+        - Real-time processing, no data storage
+        - Educational and research purposes only
+        - No patient data should be uploaded
 
         **Model:** Google MedGemma-4B | **License:** Apache 2.0
         """)
|