walaa2022 committed · verified
Commit 2c0541f · 1 Parent(s): 71446ab

Update app.py

Files changed (1): app.py (+149 -149)

app.py CHANGED
--- a/app.py (removed lines and shared context)
@@ -1,7 +1,7 @@
-# app.py - MedGemma with Fixed Authentication
 import gradio as gr
 import torch
-from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import os
 import logging
@@ -30,13 +30,14 @@ def authenticate_hf():
 # Model configuration
 MODEL_ID = "google/medgemma-4b-it"

-# Global variables for model and processor
 model = None
 processor = None

 def load_model():
-    """Load model and processor with authentication"""
-    global model, processor

     try:
         # First authenticate
@@ -45,33 +46,43 @@ def load_model():
             logger.error("❌ Authentication required for MedGemma")
             return False

-        logger.info(f"Loading model: {MODEL_ID}")
-
-        # Check if CUDA is available
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        logger.info(f"Using device: {device}")

-        # Load processor first
-        logger.info("Loading processor...")
         processor = AutoProcessor.from_pretrained(
-            MODEL_ID,
             trust_remote_code=True,
             token=True
         )
-        logger.info("✅ Processor loaded successfully")

-        # Load model with authentication
-        logger.info("Loading model...")
         model = AutoModelForImageTextToText.from_pretrained(
             MODEL_ID,
-            torch_dtype=torch.float32,  # Use float32 for CPU compatibility
-            device_map=None,  # Let PyTorch handle device placement
             trust_remote_code=True,
-            low_cpu_mem_usage=True,
             token=True
         )
         logger.info("✅ Model loaded successfully!")
-
         return True

     except Exception as e:
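For reference, the loading pattern removed above targeted CPU-only Spaces: full-precision weights with reduced peak memory. A minimal standalone sketch of that pattern, assuming a Hugging Face token is already configured and `MODEL_ID` as in app.py:

```python
# Sketch of the removed CPU-oriented loading path (not the new code).
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

MODEL_ID = "google/medgemma-4b-it"

processor = AutoProcessor.from_pretrained(MODEL_ID, token=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,   # full precision: slower, but safe on CPU
    low_cpu_mem_usage=True,      # stream weights to cut peak RAM during load
    token=True,
)
```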
@@ -85,19 +96,19 @@ model_loaded = load_model()

 def analyze_medical_image(image, clinical_question, patient_history=""):
     """Analyze medical image with clinical context"""
-    global model, processor

     # Check if model is loaded
-    if not model_loaded or model is None or processor is None:
         return """❌ **Model Loading Issue**

-The model failed to load properly. This could be due to:

-1. **Memory constraints**: The model requires significant RAM
-2. **Hardware limitations**: Consider upgrading to GPU hardware
-3. **Temporary issue**: Try refreshing the page

-**Current Status**: Model loading failed - please try again or contact support."""

     if image is None:
         return "⚠️ Please upload a medical image first."
@@ -106,65 +117,72 @@ The model failed to load properly. This could be due to:
         return "⚠️ Please provide a clinical question."

     try:
-        # Prepare the conversation
-        messages = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are MedGemma, an expert medical AI assistant specialized in medical image analysis. Provide detailed, structured analysis while emphasizing that this is for educational purposes only and should not replace professional medical diagnosis."}]
-            }
-        ]
-
-        # Build user message content
-        user_content = []
-
-        # Add patient history if provided
-        if patient_history.strip():
-            user_content.append({"type": "text", "text": f"Patient History: {patient_history}\n\n"})
-
-        # Add the clinical question
-        user_content.append({"type": "text", "text": f"Clinical Question: {clinical_question}"})
-
-        # Add the image
-        user_content.append({"type": "image", "image": image})
-
-        messages.append({
-            "role": "user",
-            "content": user_content
-        })
-
-        # Process inputs
-        logger.info("Processing input...")
-        inputs = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        )
-
-        # Move to appropriate device if model is on GPU
-        if torch.cuda.is_available() and next(model.parameters()).is_cuda:
-            device = next(model.parameters()).device
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        input_len = inputs["input_ids"].shape[-1]
-
-        # Generate response
-        logger.info("Generating response...")
-        with torch.inference_mode():
-            generation = model.generate(
-                **inputs,
-                max_new_tokens=1000,  # Reduced for stability
-                do_sample=True,
-                temperature=0.3,
-                top_p=0.95,
-                repetition_penalty=1.1,
-                pad_token_id=processor.tokenizer.eos_token_id if hasattr(processor, 'tokenizer') else None
             )
-            generation = generation[0][input_len:]

-        # Decode response
-        response = processor.decode(generation, skip_special_tokens=True)
         response = response.strip()

         # Add medical disclaimer
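Two details in the removed generation code remain worth noting: inputs must sit on the same device as the weights, and decoding should skip the prompt tokens so the reply contains only newly generated text. A sketch, given `model`, `processor`, and a tokenized `inputs` dict from `apply_chat_template` as above:

```python
# Align input tensors with the model's device, generate, then decode
# only the tokens produced after the prompt.
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}

input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=1000)
response = processor.decode(generation[0][input_len:], skip_special_tokens=True)
```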
@@ -187,31 +205,18 @@ The model failed to load properly. This could be due to:
         logger.error(f"❌ Error in analysis: {str(e)}")
         import traceback
         logger.error(f"Full traceback: {traceback.format_exc()}")
-        return f"❌ Analysis failed: {str(e)}\n\nPlease try with a different image or question."

 # Create Gradio interface
 def create_interface():
     with gr.Blocks(
-        title="MedGemma Medical Image Analysis",
         theme=gr.themes.Soft(),
         css="""
-        .gradio-container {
-            max-width: 1200px !important;
-        }
-        .disclaimer {
-            background-color: #fef2f2;
-            border: 1px solid #fecaca;
-            border-radius: 8px;
-            padding: 16px;
-            margin: 16px 0;
-        }
-        .success {
-            background-color: #f0f9ff;
-            border: 1px solid #bae6fd;
-            border-radius: 8px;
-            padding: 16px;
-            margin: 16px 0;
-        }
         """
     ) as demo:
@@ -227,21 +232,22 @@ def create_interface():

         # Status display
         if model_loaded:
-            gr.Markdown("""
             <div class="success">
-            ✅ <strong>SYSTEM READY</strong><br>
-            MedGemma model is loaded and authenticated. You can now analyze medical images.
             </div>
             """)
         else:
             gr.Markdown("""
-            <div class="disclaimer">
-            ⚠️ <strong>SYSTEM LOADING</strong><br>
-            MedGemma model is still loading. Please wait a few moments and refresh the page.
             </div>
             """)

-        # Warning banner
         gr.Markdown("""
             <div class="disclaimer">
             ⚠️ <strong>IMPORTANT MEDICAL DISCLAIMER</strong><br>
@@ -251,66 +257,64 @@ def create_interface():
             """)

         with gr.Row():
-            # Left column - Inputs
             with gr.Column(scale=1):
-                gr.Markdown("## 📤 Upload Medical Image")

                 image_input = gr.Image(
                     label="Medical Image",
                     type="pil",
-                    height=300,
-                    sources=["upload", "clipboard"]
                 )

                 clinical_question = gr.Textbox(
                     label="Clinical Question *",
-                    placeholder="Examples:\n• Describe the findings in this chest X-ray\n• What pathological changes are visible?\n• Provide differential diagnosis\n• Identify any abnormalities",
                     lines=4
                 )

                 patient_history = gr.Textbox(
                     label="Patient History (Optional)",
-                    placeholder="e.g., 65-year-old male with chronic cough and dyspnea",
                     lines=2
                 )

                 with gr.Row():
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary")
-                    analyze_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
-
-                # System status
-                auth_status = "✅ Authenticated" if model_loaded else "🔄 Loading"
-                model_status = "✅ Ready" if model_loaded else "🔄 Loading"

                 gr.Markdown(f"""
-                **Authentication:** {auth_status}
-                **Model Status:** {model_status}
-                **Device:** {'CUDA' if torch.cuda.is_available() else 'CPU'}
                 """)

-            # Right column - Output
             with gr.Column(scale=1):
-                gr.Markdown("## 📋 Medical Analysis")

                 output = gr.Textbox(
-                    label="AI Analysis Results",
                     lines=20,
                     show_copy_button=True,
-                    placeholder="Upload a medical image and ask a clinical question to get started..."
                 )

-                # Example cases
-                with gr.Accordion("📚 Example Cases", open=False):
-                    examples = gr.Examples(
-                        examples=[
-                            [
-                                "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
-                                "Analyze this chest X-ray for any abnormal findings. Comment on heart size, lung fields, and overall anatomy.",
-                                "Adult patient with respiratory symptoms"
-                            ]
-                        ],
-                        inputs=[image_input, clinical_question, patient_history]
-                    )

         # Event handlers
         analyze_btn.click(
@@ -320,11 +324,8 @@ def create_interface():
             show_progress=True
         )

-        def clear_all():
-            return None, "", "", ""
-
         clear_btn.click(
-            fn=clear_all,
             outputs=[image_input, clinical_question, patient_history, output]
         )

@@ -333,13 +334,12 @@ def create_interface():
    ---
    ### 🔬 About MedGemma

-   MedGemma-4B is Google's specialized medical AI model for educational medical image analysis.
-   It demonstrates strong performance across radiology, pathology, dermatology, and ophthalmology.

    ### 🔒 Privacy & Ethics
-   - Real-time processing with no data retention
-   - Designed for educational and research use only
-   - No PHI or patient data should be uploaded

    **Model:** Google MedGemma-4B | **License:** Apache 2.0
    """)
 
+++ b/app.py (added lines and shared context)
@@ -1,7 +1,7 @@
+# app.py - Working MedGemma with Correct Implementation
 import gradio as gr
 import torch
+from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
 from PIL import Image
 import os
 import logging
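The new code depends on a recent transformers release (the error text further down cites >= 4.52.0). A guard one could add next to these imports to fail early with a clear message; `packaging` is assumed available, as it is a transformers dependency:

```python
# Optional early check: MedGemma support needs a recent transformers.
from packaging import version
import transformers

if version.parse(transformers.__version__) < version.parse("4.52.0"):
    raise RuntimeError(
        f"transformers >= 4.52.0 required, found {transformers.__version__}"
    )
```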
 
@@ -30,13 +30,14 @@ def authenticate_hf():
 # Model configuration
 MODEL_ID = "google/medgemma-4b-it"

+# Global variables
 model = None
 processor = None
+pipeline_model = None

 def load_model():
+    """Load MedGemma model using the recommended approach"""
+    global model, processor, pipeline_model

     try:
         # First authenticate
@@ -45,33 +46,43 @@ def load_model():
             logger.error("❌ Authentication required for MedGemma")
             return False
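`authenticate_hf()` is called here, but its body falls outside the diff context. A plausible reconstruction, assuming it reads an `HF_TOKEN` secret and logs in via `huggingface_hub`; names and behavior are illustrative, not taken from the repo:

```python
# Hypothetical sketch of authenticate_hf(); the real body is not shown here.
import os
from huggingface_hub import login

def authenticate_hf() -> bool:
    token = os.environ.get("HF_TOKEN")
    if not token:
        return False          # gated model downloads will fail without it
    login(token=token)        # registers the token for later from_pretrained calls
    return True
```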
 
+        logger.info(f"Loading MedGemma: {MODEL_ID}")
+
+        # Method 1: Try using pipeline (recommended by HuggingFace)
+        try:
+            logger.info("Attempting to load using pipeline...")
+            pipeline_model = pipeline(
+                "image-text-to-text",
+                model=MODEL_ID,
+                torch_dtype=torch.float32,
+                device_map="auto" if torch.cuda.is_available() else None,
+                trust_remote_code=True
+            )
+            logger.info("✅ Pipeline model loaded successfully!")
+            return True
+        except Exception as e:
+            logger.warning(f"Pipeline loading failed: {e}")
+
+        # Method 2: Try direct model loading
+        logger.info("Attempting direct model loading...")

+        # Load processor
         processor = AutoProcessor.from_pretrained(
+            MODEL_ID,
             trust_remote_code=True,
             token=True
         )
+        logger.info("✅ Processor loaded")

+        # Load model
         model = AutoModelForImageTextToText.from_pretrained(
             MODEL_ID,
+            torch_dtype=torch.float32,
+            device_map="auto" if torch.cuda.is_available() else None,
             trust_remote_code=True,
             token=True
         )
         logger.info("✅ Model loaded successfully!")
         return True

     except Exception as e:
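Both paths above pass `device_map="auto"` only when CUDA is present; note that any `device_map` value requires the `accelerate` package at runtime. One way to keep the two call sites in sync, sketched under that assumption:

```python
# Shared loading kwargs for the pipeline and direct paths (sketch).
common_kwargs = dict(
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
if torch.cuda.is_available():
    common_kwargs["device_map"] = "auto"   # needs `accelerate` installed

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID, token=True, **common_kwargs
)
```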
 
@@ -85,19 +96,19 @@ model_loaded = load_model()

 def analyze_medical_image(image, clinical_question, patient_history=""):
     """Analyze medical image with clinical context"""
+    global model, processor, pipeline_model

     # Check if model is loaded
+    if not model_loaded:
         return """❌ **Model Loading Issue**

+MedGemma failed to load. This is likely due to:

+1. **Transformers version**: Make sure you're using transformers >= 4.52.0
+2. **Authentication**: Ensure HF_TOKEN is properly set
+3. **Model compatibility**: MedGemma requires the latest transformers library

+**Status**: Model loading failed. Please try refreshing the page or contact support."""

     if image is None:
         return "⚠️ Please upload a medical image first."
@@ -106,65 +117,72 @@ The model failed to load properly. This could be due to:
         return "⚠️ Please provide a clinical question."

     try:
+        # Method 1: Use pipeline if available
+        if pipeline_model is not None:
+            logger.info("Using pipeline for analysis...")
+
+            # Prepare message in the format expected by pipeline
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image},
+                        {"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}\n\nAs MedGemma, provide a detailed medical analysis of this image for educational purposes only."}
+                    ]
+                }
+            ]
+
+            # Generate response using pipeline
+            result = pipeline_model(messages, max_new_tokens=1000)
+
+            # Extract response text
+            response = result[0]['generated_text'] if isinstance(result, list) else result['generated_text']
+
+        # Method 2: Use direct model if pipeline failed
+        elif model is not None and processor is not None:
+            logger.info("Using direct model for analysis...")
+
+            # Prepare messages for direct model
+            messages = [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": "You are MedGemma, an expert medical AI assistant. Provide detailed medical analysis for educational purposes only."}]
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": f"Patient History: {patient_history}\n\nClinical Question: {clinical_question}"},
+                        {"type": "image", "image": image}
+                    ]
+                }
+            ]
+
+            # Process inputs
+            inputs = processor.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
             )
+
+            # Generate response
+            with torch.inference_mode():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=1000,
+                    do_sample=True,
+                    temperature=0.3,
+                    top_p=0.9
+                )
+
+            # Decode response
+            response = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

+        else:
+            return "❌ No model available for analysis. Please try refreshing the page."
+
+        # Clean up response
         response = response.strip()

         # Add medical disclaimer
@@ -187,31 +205,18 @@ The model failed to load properly. This could be due to:
         logger.error(f"❌ Error in analysis: {str(e)}")
         import traceback
         logger.error(f"Full traceback: {traceback.format_exc()}")
+        return f"❌ Analysis failed: {str(e)}\n\nPlease try again with a different image or question."
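A caveat on the pipeline branch above: with chat-style input, recent transformers versions can return `generated_text` as the whole message list rather than a plain string, so the simple extraction may hand Gradio a list. A more defensive sketch, given `result` from the pipeline call:

```python
# Defensive extraction of the pipeline result (sketch).
out = result[0]["generated_text"] if isinstance(result, list) else result["generated_text"]
if isinstance(out, list):                  # chat format: list of message dicts
    out = out[-1].get("content", "")       # keep only the assistant turn
response = out if isinstance(out, str) else str(out)
```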
 
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(
+        title="MedGemma Medical Analysis",
         theme=gr.themes.Soft(),
         css="""
+        .gradio-container { max-width: 1200px !important; }
+        .disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
+        .success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
+        .warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
         """
     ) as demo:
@@ -227,21 +232,22 @@ def create_interface():

         # Status display
         if model_loaded:
+            method = "Pipeline" if pipeline_model else "Direct Model"
+            gr.Markdown(f"""
             <div class="success">
+            ✅ <strong>MEDGEMMA READY</strong><br>
+            Model loaded successfully using {method} method. Ready for medical image analysis.
             </div>
             """)
         else:
             gr.Markdown("""
+            <div class="warning">
+            ⚠️ <strong>MODEL LOADING FAILED</strong><br>
+            MedGemma failed to load. Please ensure you have the latest transformers library and proper authentication.
             </div>
             """)

+        # Medical disclaimer
         gr.Markdown("""
             <div class="disclaimer">
             ⚠️ <strong>IMPORTANT MEDICAL DISCLAIMER</strong><br>
@@ -251,66 +257,64 @@ def create_interface():
             """)

         with gr.Row():
+            # Left column
             with gr.Column(scale=1):
+                gr.Markdown("## 📤 Medical Image Upload")

                 image_input = gr.Image(
                     label="Medical Image",
                     type="pil",
+                    height=300
                 )

                 clinical_question = gr.Textbox(
                     label="Clinical Question *",
+                    placeholder="Examples:\n• Describe findings in this chest X-ray\n• What pathological changes are visible?\n• Provide differential diagnosis\n• Identify abnormalities",
                     lines=4
                 )

                 patient_history = gr.Textbox(
                     label="Patient History (Optional)",
+                    placeholder="e.g., 65-year-old male with chronic cough",
                     lines=2
                 )

                 with gr.Row():
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary")
+                    analyze_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")

+                # System info
                 gr.Markdown(f"""
+                **Status:** {'✅ Ready' if model_loaded else '❌ Failed'}
+                **Method:** {'Pipeline' if pipeline_model else 'Direct' if model else 'None'}
+                **Device:** {'CUDA' if torch.cuda.is_available() else 'CPU'}
+                **Transformers:** {getattr(__import__('transformers'), '__version__', 'Unknown')}
                 """)
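The `__import__('transformers')` readout above works, but since transformers is already imported at module scope in this file, a plain attribute access is simpler and equivalent:

```python
# Equivalent version readout without __import__ (sketch).
import transformers

transformers_version = getattr(transformers, "__version__", "Unknown")
```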
 
+            # Right column
             with gr.Column(scale=1):
+                gr.Markdown("## 📋 Medical Analysis Results")

                 output = gr.Textbox(
+                    label="AI Medical Analysis",
                     lines=20,
                     show_copy_button=True,
+                    placeholder="Upload a medical image and ask a clinical question..." if model_loaded else "Model unavailable - please check system status"
                 )

+        # Examples
+        if model_loaded:
+            with gr.Accordion("📚 Example Cases", open=False):
+                examples = gr.Examples(
+                    examples=[
+                        [
+                            "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
+                            "Analyze this chest X-ray systematically. Comment on heart size, lung fields, and any abnormalities.",
+                            "Adult patient with respiratory symptoms"
+                        ]
+                    ],
+                    inputs=[image_input, clinical_question, patient_history]
+                )
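The example row points at a remote image URL, which Gradio fetches when the example is used. If Space startup or caching behavior matters, caching can be pinned explicitly; a sketch of the same block with that one addition:

```python
# Same Examples block, with caching disabled so the remote X-ray
# is only downloaded when a user clicks the example (sketch).
gr.Examples(
    examples=[[
        "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
        "Analyze this chest X-ray systematically. Comment on heart size, lung fields, and any abnormalities.",
        "Adult patient with respiratory symptoms",
    ]],
    inputs=[image_input, clinical_question, patient_history],
    cache_examples=False,
)
```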
 
         # Event handlers
         analyze_btn.click(
@@ -320,11 +324,8 @@ def create_interface():
             show_progress=True
         )

         clear_btn.click(
+            fn=lambda: (None, "", "", ""),
             outputs=[image_input, clinical_question, patient_history, output]
         )
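The `analyze_btn.click` call shows only `show_progress=True`; its other arguments are unchanged context outside the hunk. A plausible full wiring, inferred from the component names (the `fn`/`inputs`/`outputs` lines are assumptions, not shown in this diff):

```python
# Assumed complete event wiring for both buttons (sketch).
analyze_btn.click(
    fn=analyze_medical_image,
    inputs=[image_input, clinical_question, patient_history],
    outputs=output,
    show_progress=True,
)
clear_btn.click(
    fn=lambda: (None, "", "", ""),
    outputs=[image_input, clinical_question, patient_history, output],
)
```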
 
 
@@ -333,13 +334,12 @@ def create_interface():
    ---
    ### 🔬 About MedGemma

+   MedGemma-4B is Google's specialized medical AI model requiring transformers >= 4.52.0.

    ### 🔒 Privacy & Ethics
+   - Real-time processing, no data storage
+   - Educational and research purposes only
+   - No patient data should be uploaded

    **Model:** Google MedGemma-4B | **License:** Apache 2.0
    """)