walaa2022 commited on
Commit
644aa62
·
verified ·
1 Parent(s): 815f7ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -433
app.py CHANGED
@@ -1,461 +1,296 @@
1
- # app.py - Fixed MedGemma Implementation Based on Google's Official Approach
2
  import gradio as gr
3
  import torch
4
- import os
5
- import logging
6
- import json
7
- import requests
 
 
 
 
8
  from PIL import Image
9
- import base64
10
- import io
11
- from huggingface_hub import login
12
- from collections import defaultdict, Counter
13
- import time
14
 
15
- # Configure logging
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
 
19
- # Usage tracking
20
- class UsageTracker:
21
  def __init__(self):
22
- self.stats = {
23
- 'total_analyses': 0,
24
- 'successful_analyses': 0,
25
- 'failed_analyses': 0,
26
- 'average_processing_time': 0.0,
27
- 'question_types': Counter()
 
 
 
 
 
28
  }
29
-
30
- def log_analysis(self, success, duration, question_type=None):
31
- self.stats['total_analyses'] += 1
32
- if success:
33
- self.stats['successful_analyses'] += 1
34
- else:
35
- self.stats['failed_analyses'] += 1
36
 
37
- total_time = self.stats['average_processing_time'] * (self.stats['total_analyses'] - 1)
38
- self.stats['average_processing_time'] = (total_time + duration) / self.stats['total_analyses']
39
 
40
- if question_type:
41
- self.stats['question_types'][question_type] += 1
42
-
43
- # Rate limiting
44
- class RateLimiter:
45
- def __init__(self, max_requests_per_hour=50):
46
- self.max_requests_per_hour = max_requests_per_hour
47
- self.requests = defaultdict(list)
48
 
49
- def is_allowed(self, user_id="default"):
50
- current_time = time.time()
51
- hour_ago = current_time - 3600
52
- self.requests[user_id] = [req_time for req_time in self.requests[user_id] if req_time > hour_ago]
53
- if len(self.requests[user_id]) < self.max_requests_per_hour:
54
- self.requests[user_id].append(current_time)
55
- return True
56
- return False
57
-
58
- # Initialize components
59
- usage_tracker = UsageTracker()
60
- rate_limiter = RateLimiter()
61
-
62
- # MedGemma API Configuration
63
- MODEL_ID = "google/medgemma-4b-it"
64
-
65
- def authenticate_hf():
66
- """Authenticate with Hugging Face"""
67
- try:
68
- hf_token = os.getenv('HF_TOKEN')
69
- if hf_token:
70
- login(token=hf_token)
71
- logger.info("✅ Authenticated with Hugging Face")
72
- return True, hf_token
73
- else:
74
- logger.warning("⚠️ No HF_TOKEN found")
75
- return False, None
76
- except Exception as e:
77
- logger.error(f"❌ Authentication failed: {e}")
78
- return False, None
79
-
80
- def image_to_base64(image):
81
- """Convert PIL image to base64 string"""
82
- try:
83
- buffer = io.BytesIO()
84
- image.save(buffer, format='PNG')
85
- img_str = base64.b64encode(buffer.getvalue()).decode()
86
- return f"data:image/png;base64,{img_str}"
87
- except Exception as e:
88
- logger.error(f"Error converting image: {e}")
89
- return None
90
-
91
- def call_medgemma_api(image, prompt, patient_history="", hf_token=None):
92
- """Call MedGemma using Hugging Face Inference API"""
93
- try:
94
- # Use HF Inference API endpoint
95
- api_url = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
96
-
97
- headers = {
98
- "Authorization": f"Bearer {hf_token}",
99
- "Content-Type": "application/json"
100
- }
101
-
102
- # Prepare the payload following Google's format
103
- system_instruction = "You are an expert medical AI assistant specialized in medical image analysis. Provide detailed analysis for educational purposes only."
104
-
105
- # Build the full prompt
106
- full_prompt = system_instruction + " "
107
- if patient_history.strip():
108
- full_prompt += f"Patient History: {patient_history} "
109
- full_prompt += prompt
110
-
111
- # Convert image to base64
112
- image_b64 = image_to_base64(image)
113
- if not image_b64:
114
- return None, "Failed to process image"
115
-
116
- # Prepare the request payload
117
- payload = {
118
- "inputs": {
119
- "prompt": full_prompt,
120
- "multi_modal_data": {
121
- "image": image_b64
122
- },
123
- "max_tokens": 1000,
124
- "temperature": 0.3,
125
- "raw_response": True
126
- }
127
- }
128
-
129
- # Make the API call
130
- response = requests.post(api_url, headers=headers, json=payload, timeout=120)
131
-
132
- if response.status_code == 200:
133
- result = response.json()
134
- if isinstance(result, list) and len(result) > 0:
135
- return result[0].get('generated_text', ''), None
136
- elif isinstance(result, dict):
137
- return result.get('generated_text', result.get('text', str(result))), None
138
  else:
139
- return str(result), None
140
- else:
141
- error_msg = f"API Error {response.status_code}: {response.text}"
142
- logger.error(error_msg)
143
- return None, error_msg
 
 
 
 
 
 
 
 
 
144
 
145
- except requests.exceptions.Timeout:
146
- return None, "Request timeout - model may be loading"
147
- except Exception as e:
148
- logger.error(f"API call failed: {e}")
149
- return None, str(e)
150
-
151
- def analyze_medical_image_medgemma(image, clinical_question, patient_history=""):
152
- """Main analysis function using MedGemma"""
153
- start_time = time.time()
154
-
155
- # Rate limiting
156
- if not rate_limiter.is_allowed():
157
- usage_tracker.log_analysis(False, time.time() - start_time)
158
- return "⚠️ Too many requests. Please wait before trying again."
159
-
160
- # Validate inputs
161
- if image is None:
162
- return "⚠️ Please upload a medical image first."
163
 
164
- if not clinical_question.strip():
165
- return "⚠️ Please provide a clinical question."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- # Authenticate
168
- auth_success, hf_token = authenticate_hf()
169
- if not auth_success or not hf_token:
170
- usage_tracker.log_analysis(False, time.time() - start_time)
171
- return """❌ **Authentication Required**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- To use MedGemma, you need:
174
- 1. Access to the model at https://huggingface.co/google/medgemma-4b-it
175
- 2. HF_TOKEN set in Space Settings → Repository secrets
176
 
177
- **Current Status**: Authentication failed - cannot access MedGemma."""
 
 
 
178
 
179
- try:
180
- logger.info("Calling MedGemma API...")
181
-
182
- # Call MedGemma API
183
- response_text, error = call_medgemma_api(
184
- image=image,
185
- prompt=clinical_question,
186
- patient_history=patient_history,
187
- hf_token=hf_token
188
- )
189
-
190
- if error:
191
- usage_tracker.log_analysis(False, time.time() - start_time)
192
- return f"""❌ **MedGemma API Error**
193
-
194
- {error}
195
-
196
- **Possible solutions:**
197
- 1. The model may be loading - try again in a few minutes
198
- 2. Check if you have proper access to MedGemma
199
- 3. Verify your HF_TOKEN is valid
200
-
201
- **Note**: MedGemma is a gated model and may have usage limits."""
202
-
203
- if not response_text:
204
- usage_tracker.log_analysis(False, time.time() - start_time)
205
- return "❌ No response from MedGemma. Please try again."
206
-
207
- # Clean up response
208
- response_text = response_text.strip()
209
-
210
- # Add medical disclaimer
211
- disclaimer = """
212
-
213
- ---
214
- ### ⚠️ MEDICAL DISCLAIMER
215
- **This analysis is for educational and research purposes only.**
216
- - This AI assistant is not a substitute for professional medical advice
217
- - Always consult qualified healthcare professionals for diagnosis and treatment
218
- - Do not make medical decisions based solely on this analysis
219
- - In case of medical emergency, contact emergency services immediately
220
- ---
221
-
222
- **Powered by**: Google MedGemma-4B via Hugging Face Inference API
223
- """
224
-
225
- # Log successful analysis
226
- duration = time.time() - start_time
227
- question_type = classify_question(clinical_question)
228
- usage_tracker.log_analysis(True, duration, question_type)
229
-
230
- logger.info("✅ MedGemma analysis completed successfully")
231
- return response_text + disclaimer
232
-
233
- except Exception as e:
234
- duration = time.time() - start_time
235
- usage_tracker.log_analysis(False, duration)
236
- logger.error(f"❌ Analysis error: {str(e)}")
237
- return f"❌ Analysis failed: {str(e)}\n\nPlease try again or use a different image."
238
-
239
- def classify_question(question):
240
- """Classify clinical question type"""
241
- question_lower = question.lower()
242
- if any(word in question_lower for word in ['describe', 'findings', 'observe']):
243
- return 'descriptive'
244
- elif any(word in question_lower for word in ['diagnosis', 'differential', 'condition']):
245
- return 'diagnostic'
246
- elif any(word in question_lower for word in ['abnormal', 'pathology', 'disease']):
247
- return 'pathological'
248
- else:
249
- return 'general'
250
-
251
- def get_usage_stats():
252
- """Get usage statistics"""
253
- stats = usage_tracker.stats
254
- if stats['total_analyses'] == 0:
255
- return "📊 **Usage Statistics**\n\nNo analyses performed yet."
256
 
257
- success_rate = (stats['successful_analyses'] / stats['total_analyses']) * 100
 
258
 
259
- return f"""📊 **Usage Statistics**
260
-
261
- **Performance:**
262
- - Total Analyses: {stats['total_analyses']}
263
- - Success Rate: {success_rate:.1f}%
264
- - Avg Processing Time: {stats['average_processing_time']:.2f}s
265
-
266
- **Popular Question Types:**
267
- {chr(10).join([f"- {qtype}: {count}" for qtype, count in stats['question_types'].most_common(3)])}
268
- """
269
-
270
- # Create Gradio interface
271
- def create_interface():
272
- # Check authentication status
273
- auth_success, _ = authenticate_hf()
 
 
274
 
275
- with gr.Blocks(
276
- title="MedGemma Medical Analysis",
277
- theme=gr.themes.Soft(),
278
- css="""
279
- .gradio-container { max-width: 1200px !important; }
280
- .disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
281
- .success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px; margin: 16px 0; }
282
- .warning { background-color: #fffbeb; border: 1px solid #fed7aa; border-radius: 8px; padding: 16px; margin: 16px 0; }
283
- """
284
- ) as demo:
285
-
286
- # Header
287
- gr.Markdown("""
288
- # 🏥 MedGemma Medical Image Analysis
289
-
290
- **Google's Medical AI Assistant - MedGemma-4B**
291
-
292
- Specialized medical AI trained specifically for:
293
- 🫁 **Radiology** • 🔬 **Histopathology** • 👁️ **Ophthalmology** • 🩺 **Dermatology**
294
- """)
295
-
296
- # Status display
297
- if auth_success:
298
- gr.Markdown("""
299
- <div class="success">
300
- ✅ <strong>MEDGEMMA READY</strong><br>
301
- Authenticated with Google's MedGemma-4B model. Ready for professional medical image analysis.
302
- </div>
303
- """)
304
- else:
305
- gr.Markdown("""
306
- <div class="warning">
307
- 🔐 <strong>AUTHENTICATION REQUIRED</strong><br>
308
- Please ensure HF_TOKEN is set in Space Settings → Repository secrets and you have access to MedGemma.
309
- </div>
310
- """)
311
-
312
- # Medical disclaimer
313
- gr.Markdown("""
314
- <div class="disclaimer">
315
- ⚠️ <strong>IMPORTANT MEDICAL DISCLAIMER</strong><br>
316
- This tool is for <strong>educational and research purposes only</strong>.
317
- Do not upload real patient data. Always consult qualified healthcare professionals.
318
- </div>
319
- """)
320
-
321
- with gr.Row():
322
- # Left column
323
- with gr.Column(scale=2):
324
- with gr.Row():
325
- with gr.Column():
326
- gr.Markdown("## 📤 Medical Image")
327
- image_input = gr.Image(
328
- label="Upload Medical Image",
329
- type="pil",
330
- height=300
331
- )
332
-
333
- with gr.Column():
334
- gr.Markdown("## 💬 Clinical Query")
335
- clinical_question = gr.Textbox(
336
- label="Clinical Question *",
337
- placeholder="Examples:\n• Describe this X-ray systematically\n• What pathological changes are visible?\n• Provide differential diagnosis\n• Assess image quality and findings",
338
- lines=4
339
- )
340
-
341
- patient_history = gr.Textbox(
342
- label="Patient History (Optional)",
343
- placeholder="e.g., 65-year-old male with chronic cough, smoking history",
344
- lines=2
345
  )
 
346
 
347
- with gr.Row():
348
- clear_btn = gr.Button("🗑️ Clear", variant="secondary")
349
- analyze_btn = gr.Button("🔍 Analyze with MedGemma", variant="primary", size="lg")
350
-
351
- gr.Markdown("## 📋 MedGemma Analysis")
352
- output = gr.Textbox(
353
- label="Medical AI Analysis Results",
354
- lines=20,
355
- show_copy_button=True,
356
- placeholder="Upload a medical image and ask a clinical question to get started..."
357
- )
 
358
 
359
- # Right column - System info
360
- with gr.Column(scale=1):
361
- gr.Markdown("## ℹ️ System Status")
362
-
363
- auth_status = " Authenticated" if auth_success else "🔐 Auth Required"
364
-
365
- gr.Markdown(f"""
366
- **Authentication:** {auth_status}
367
- **Model:** Google MedGemma-4B
368
- **API:** Hugging Face Inference
369
- **Status:** {'Ready' if auth_success else 'Setup Required'}
370
- """)
 
 
 
 
 
 
 
 
371
 
372
- gr.Markdown("## 📊 Usage Statistics")
373
- stats_display = gr.Markdown("")
374
- refresh_stats_btn = gr.Button("🔄 Refresh Stats", size="sm")
375
-
376
- gr.Markdown("## 🎯 Quick Examples")
377
-
378
- chest_btn = gr.Button("Chest X-ray", size="sm")
379
- pathology_btn = gr.Button("Pathology", size="sm")
380
- diagnosis_btn = gr.Button("Diagnosis", size="sm")
381
-
382
- # Example cases
383
- with gr.Accordion("📚 Medical Cases", open=False):
384
- examples = gr.Examples(
385
- examples=[
386
- [
387
- "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
388
- "You are an expert radiologist. Describe this X-ray systematically including heart size, lung fields, and any abnormalities.",
389
- "Adult patient with respiratory symptoms"
390
- ]
391
- ],
392
- inputs=[image_input, clinical_question, patient_history]
393
- )
394
-
395
- # Event handlers
396
- analyze_btn.click(
397
- fn=analyze_medical_image_medgemma,
398
- inputs=[image_input, clinical_question, patient_history],
399
- outputs=output,
400
- show_progress=True
401
- )
402
-
403
- clear_btn.click(
404
- fn=lambda: (None, "", "", ""),
405
- outputs=[image_input, clinical_question, patient_history, output]
406
- )
407
-
408
- refresh_stats_btn.click(
409
- fn=get_usage_stats,
410
- outputs=stats_display
411
- )
412
-
413
- # Quick example handlers
414
- chest_btn.click(
415
- fn=lambda: ("Analyze this chest X-ray systematically. Comment on cardiac silhouette, lung fields, mediastinum, and any pathological findings.", "Adult with respiratory symptoms"),
416
- outputs=[clinical_question, patient_history]
417
- )
418
-
419
- pathology_btn.click(
420
- fn=lambda: ("What pathological changes are visible in this medical image? Provide structured analysis with clinical significance.", ""),
421
- outputs=[clinical_question, patient_history]
422
- )
423
-
424
- diagnosis_btn.click(
425
- fn=lambda: ("Based on the imaging findings, what are the most likely differential diagnoses? Consider clinical context.", "Patient with acute presentation"),
426
- outputs=[clinical_question, patient_history]
427
- )
428
-
429
- # Footer
430
- gr.Markdown("""
431
- ---
432
- ### 🔬 About MedGemma
433
-
434
- **MedGemma-4B** is Google's specialized medical AI model designed specifically for medical image analysis and clinical reasoning.
435
- It represents state-of-the-art performance in medical AI applications.
436
-
437
- **Key Features:**
438
- - **Medical Specialization**: Trained specifically on medical imaging data
439
- - **Multi-modal**: Handles both images and clinical text
440
- - **Professional Grade**: Designed for medical education and research
441
- - **Google Quality**: Built by Google's medical AI team
442
-
443
- ### 🔒 Privacy & Compliance
444
- - **Real-time processing** with no data retention
445
- - **Educational purpose** design and disclaimers
446
- - **HIPAA-aware** interface (no PHI uploads)
447
- - **Professional standards** for medical AI applications
448
 
449
- **Model:** Google MedGemma-4B | **API:** Hugging Face Inference | **License:** Apache 2.0
450
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
- return demo
 
 
453
 
454
- # Launch the app
455
  if __name__ == "__main__":
456
- demo = create_interface()
457
- demo.launch(
458
- server_name="0.0.0.0",
459
- server_port=7860,
460
- show_error=True
461
- )
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoModelForImageTextToText,
6
+ AutoTokenizer,
7
+ AutoProcessor,
8
+ BitsAndBytesConfig,
9
+ pipeline
10
+ )
11
  from PIL import Image
12
+ import os
13
+ import spaces
 
 
 
14
 
15
+ # Configuration
16
+ MODEL_4B = "google/medgemma-4b-it"
17
+ MODEL_27B = "google/medgemma-27b-text-it"
18
 
19
+ class MedGemmaApp:
 
20
  def __init__(self):
21
+ self.current_model = None
22
+ self.current_tokenizer = None
23
+ self.current_processor = None
24
+ self.current_pipe = None
25
+ self.model_type = None
26
+
27
+ def get_model_kwargs(self, use_quantization=True):
28
+ """Get model configuration arguments"""
29
+ model_kwargs = {
30
+ "torch_dtype": torch.bfloat16,
31
+ "device_map": "auto",
32
  }
 
 
 
 
 
 
 
33
 
34
+ if use_quantization:
35
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
36
 
37
+ return model_kwargs
 
 
 
 
 
 
 
38
 
39
+ @spaces.GPU
40
+ def load_model(self, model_choice, use_quantization=True):
41
+ """Load the selected model"""
42
+ try:
43
+ model_id = MODEL_4B if model_choice == "4B (Multimodal)" else MODEL_27B
44
+ model_kwargs = self.get_model_kwargs(use_quantization)
45
+
46
+ # Clear previous model
47
+ if self.current_model is not None:
48
+ del self.current_model
49
+ del self.current_tokenizer
50
+ if self.current_processor:
51
+ del self.current_processor
52
+ if self.current_pipe:
53
+ del self.current_pipe
54
+ torch.cuda.empty_cache()
55
+
56
+ if model_choice == "4B (Multimodal)":
57
+ # Load multimodal model
58
+ self.current_model = AutoModelForImageTextToText.from_pretrained(
59
+ model_id, **model_kwargs
60
+ )
61
+ self.current_processor = AutoProcessor.from_pretrained(model_id)
62
+ self.model_type = "multimodal"
63
+
64
+ # Create pipeline for easier inference
65
+ self.current_pipe = pipeline(
66
+ "image-text-to-text",
67
+ model=self.current_model,
68
+ processor=self.current_processor,
69
+ )
70
+ self.current_pipe.model.generation_config.do_sample = False
71
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  else:
73
+ # Load text-only model
74
+ self.current_model = AutoModelForCausalLM.from_pretrained(
75
+ model_id, **model_kwargs
76
+ )
77
+ self.current_tokenizer = AutoTokenizer.from_pretrained(model_id)
78
+ self.model_type = "text"
79
+
80
+ # Create pipeline for easier inference
81
+ self.current_pipe = pipeline(
82
+ "text-generation",
83
+ model=self.current_model,
84
+ tokenizer=self.current_tokenizer,
85
+ )
86
+ self.current_pipe.model.generation_config.do_sample = False
87
 
88
+ return f"✅ Successfully loaded {model_choice} model!"
89
+
90
+ except Exception as e:
91
+ return f" Error loading model: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ @spaces.GPU
94
+ def chat_text_only(self, message, history, system_instruction="You are a helpful medical assistant."):
95
+ """Handle text-only conversations"""
96
+ if self.current_model is None or self.model_type != "text":
97
+ return "Please load the 27B (Text Only) model first!"
98
+
99
+ try:
100
+ messages = [
101
+ {"role": "system", "content": system_instruction},
102
+ {"role": "user", "content": message}
103
+ ]
104
+
105
+ # Add conversation history
106
+ for human, assistant in history:
107
+ messages.insert(-1, {"role": "user", "content": human})
108
+ messages.insert(-1, {"role": "assistant", "content": assistant})
109
+
110
+ output = self.current_pipe(messages, max_new_tokens=500)
111
+ response = output[0]["generated_text"][-1]["content"]
112
+
113
+ return response
114
+
115
+ except Exception as e:
116
+ return f"Error generating response: {str(e)}"
117
 
118
+ @spaces.GPU
119
+ def chat_with_image(self, message, image, system_instruction="You are an expert radiologist."):
120
+ """Handle image + text conversations"""
121
+ if self.current_model is None or self.model_type != "multimodal":
122
+ return "Please load the 4B (Multimodal) model first!"
123
+
124
+ if image is None:
125
+ return "Please upload an image to analyze."
126
+
127
+ try:
128
+ messages = [
129
+ {
130
+ "role": "system",
131
+ "content": [{"type": "text", "text": system_instruction}]
132
+ },
133
+ {
134
+ "role": "user",
135
+ "content": [
136
+ {"type": "text", "text": message},
137
+ {"type": "image", "image": image}
138
+ ]
139
+ }
140
+ ]
141
+
142
+ output = self.current_pipe(text=messages, max_new_tokens=300)
143
+ response = output[0]["generated_text"][-1]["content"]
144
+
145
+ return response
146
+
147
+ except Exception as e:
148
+ return f"Error analyzing image: {str(e)}"
149
 
150
+ # Initialize the app
151
+ app = MedGemmaApp()
 
152
 
153
+ # Create Gradio interface
154
+ with gr.Blocks(title="MedGemma Medical AI Assistant", theme=gr.themes.Soft()) as demo:
155
+ gr.Markdown("""
156
+ # 🏥 MedGemma Medical AI Assistant
157
 
158
+ Welcome to MedGemma, Google's medical AI assistant! Choose between:
159
+ - **4B Multimodal**: Analyze medical images (X-rays, scans) with text
160
+ - **27B Text-Only**: Advanced medical text conversations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ > **Note**: This is for educational and research purposes only. Always consult healthcare professionals for medical advice.
163
+ """)
164
 
165
+ with gr.Row():
166
+ with gr.Column(scale=1):
167
+ model_choice = gr.Radio(
168
+ choices=["4B (Multimodal)", "27B (Text Only)"],
169
+ value="4B (Multimodal)",
170
+ label="Select Model",
171
+ info="4B supports images, 27B is text-only but more powerful"
172
+ )
173
+
174
+ use_quantization = gr.Checkbox(
175
+ value=True,
176
+ label="Use 4-bit Quantization",
177
+ info="Reduces memory usage (recommended)"
178
+ )
179
+
180
+ load_btn = gr.Button("🚀 Load Model", variant="primary")
181
+ model_status = gr.Textbox(label="Model Status", interactive=False)
182
 
183
+ with gr.Tabs():
184
+ # Text-only chat tab
185
+ with gr.Tab("💬 Text Chat", id="text_chat"):
186
+ gr.Markdown("### Medical Text Consultation")
187
+
188
+ with gr.Row():
189
+ with gr.Column(scale=3):
190
+ text_system = gr.Textbox(
191
+ value="You are a helpful medical assistant.",
192
+ label="System Instruction",
193
+ placeholder="Set the AI's role and behavior..."
194
+ )
195
+
196
+ chatbot_text = gr.Chatbot(
197
+ height=400,
198
+ placeholder="Start a medical conversation...",
199
+ label="Medical Assistant"
200
+ )
201
+
202
+ with gr.Row():
203
+ text_input = gr.Textbox(
204
+ placeholder="Ask a medical question...",
205
+ label="Your Question",
206
+ scale=4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  )
208
+ text_submit = gr.Button("Send", scale=1)
209
 
210
+ with gr.Column(scale=1):
211
+ gr.Markdown("""
212
+ ### 💡 Example Questions:
213
+ - How do you differentiate bacterial from viral pneumonia?
214
+ - What are the symptoms of diabetes?
215
+ - Explain the mechanism of action of ACE inhibitors
216
+ - What are the contraindications for MRI?
217
+ """)
218
+
219
+ # Image analysis tab
220
+ with gr.Tab("🖼️ Image Analysis", id="image_analysis"):
221
+ gr.Markdown("### Medical Image Analysis")
222
 
223
+ with gr.Row():
224
+ with gr.Column(scale=2):
225
+ image_input = gr.Image(
226
+ type="pil",
227
+ label="Upload Medical Image",
228
+ height=300
229
+ )
230
+
231
+ image_system = gr.Textbox(
232
+ value="You are an expert radiologist.",
233
+ label="System Instruction"
234
+ )
235
+
236
+ image_text_input = gr.Textbox(
237
+ value="Describe this X-ray",
238
+ label="Question about the image",
239
+ placeholder="What would you like to know about this image?"
240
+ )
241
+
242
+ image_submit = gr.Button("🔍 Analyze Image", variant="primary")
243
 
244
+ with gr.Column(scale=2):
245
+ image_output = gr.Textbox(
246
+ label="Analysis Result",
247
+ lines=15,
248
+ placeholder="Upload an image and click 'Analyze Image' to see the AI's analysis..."
249
+ )
250
+
251
+ # Event handlers
252
+ load_btn.click(
253
+ fn=app.load_model,
254
+ inputs=[model_choice, use_quantization],
255
+ outputs=[model_status]
256
+ )
257
+
258
+ def respond_text(message, history, system_instruction):
259
+ if message.strip() == "":
260
+ return history, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
+ response = app.chat_text_only(message, history, system_instruction)
263
+ history.append((message, response))
264
+ return history, ""
265
+
266
+ text_submit.click(
267
+ fn=respond_text,
268
+ inputs=[text_input, chatbot_text, text_system],
269
+ outputs=[chatbot_text, text_input]
270
+ )
271
+
272
+ text_input.submit(
273
+ fn=respond_text,
274
+ inputs=[text_input, chatbot_text, text_system],
275
+ outputs=[chatbot_text, text_input]
276
+ )
277
+
278
+ image_submit.click(
279
+ fn=app.chat_with_image,
280
+ inputs=[image_text_input, image_input, image_system],
281
+ outputs=[image_output]
282
+ )
283
+
284
+ # Example image loading
285
+ gr.Markdown("""
286
+ ---
287
+ ### 📚 About MedGemma
288
+ MedGemma is a collection of Gemma variants trained for medical applications.
289
+ Learn more at the [HAI-DEF developer site](https://developers.google.com/health-ai-developer-foundations/medgemma).
290
 
291
+ **Disclaimer**: This tool is for educational and research purposes only.
292
+ Always consult qualified healthcare professionals for medical advice.
293
+ """)
294
 
 
295
  if __name__ == "__main__":
296
+ demo.launch()