walaa2022 commited on
Commit
b458509
·
verified ·
1 Parent(s): ab0ad58

Upload 2 files

Browse files
Files changed (2) hide show
  1. medgemma_space.py +182 -0
  2. requirements_txt.txt +7 -0
medgemma_space.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - Main Gradio application
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoProcessor, AutoModelForImageTextToText
5
+ from PIL import Image
6
+ import os
7
+
8
+ # Model configuration
9
+ MODEL_ID = "google/medgemma-4b-it"
10
+
11
# Load model and processor
def load_model():
    """Load the MedGemma model and its matching processor from the Hub.

    Returns:
        tuple: ``(model, processor)`` — the image-text-to-text model with
        weights placed/sharded automatically, and its processor.
    """
    # FIX: the original decorated this with @gr.utils.async_wrapper, which is
    # not part of Gradio's public API and raises AttributeError at import time
    # on current Gradio versions. Plain synchronous loading is correct here:
    # the function runs exactly once at module import.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # halves memory vs fp32
        device_map="auto",           # let accelerate decide device placement
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor
22
+
23
# Initialize model and processor once at import time (module-level globals
# used by analyze_medical_image below).
# NOTE(review): nothing here actually caches beyond normal module semantics —
# the original "(this will be cached)" comment was misleading. Loading at
# import blocks app startup for the full download/load.
model, processor = load_model()
25
+
26
def analyze_medical_image(image, clinical_question, patient_history=""):
    """Analyze a medical image with clinical context using MedGemma.

    Args:
        image: ``PIL.Image.Image`` or ``None`` — the uploaded medical image.
        clinical_question: str or None — question to ask about the image.
        patient_history: str or None — optional free-text history. Gradio
            may deliver ``None`` for a cleared textbox, so it is guarded.

    Returns:
        str: the model's analysis plus an educational-use disclaimer, or an
        instructive/error message.
    """
    if image is None:
        return "Please upload an image first."
    # Robustness: avoid sending an empty prompt to the model.
    if not clinical_question or not clinical_question.strip():
        return "Please enter a clinical question."

    # System turn framing the assistant's role and the educational scope.
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert medical AI assistant. Provide detailed analysis while emphasizing that this is for educational purposes only and should not replace professional medical diagnosis."}]
        }
    ]

    # Optional patient history as its own user turn.
    # FIX: the original called patient_history.strip() unconditionally, which
    # raises AttributeError when Gradio passes None for a cleared textbox.
    if patient_history and patient_history.strip():
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": f"Patient History: {patient_history}"}]
        })

    # Main question together with the image. image is known non-None here,
    # so the original's redundant `if image:` re-check is dropped.
    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": clinical_question},
            {"type": "image", "image": image},
        ]
    })

    try:
        # Encode the chat; dtype=torch.bfloat16 casts only floating-point
        # tensors (pixel values) — integer token ids keep their dtype.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]

        # Sampled generation; inference_mode disables autograd bookkeeping.
        with torch.inference_mode():
            generation = model.generate(
                **inputs,
                max_new_tokens=1000,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )
            # Keep only the newly generated tokens (strip the prompt echo).
            generation = generation[0][input_len:]

        response = processor.decode(generation, skip_special_tokens=True)

        # Always append the safety disclaimer to the model output.
        disclaimer = "\n\n⚠️ IMPORTANT DISCLAIMER: This analysis is for educational and research purposes only. Always consult qualified healthcare professionals for medical diagnosis and treatment decisions."

        return response + disclaimer

    except Exception as e:
        # Surface the failure in the UI textbox rather than crashing the worker.
        return f"Error processing request: {str(e)}"
91
+
92
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI.

    Wires the image / clinical-question / patient-history inputs to
    analyze_medical_image and adds header, example cases, and a data-privacy
    footer. The returned Blocks object is launched by the __main__ guard.
    """
    with gr.Blocks(title="MedGemma Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        # Page header with the educational-use warning.
        gr.Markdown("""
        # 🏥 MedGemma Medical Image Analysis
        
        **Educational Medical AI Assistant powered by Google's MedGemma-4B**
        
        ⚠️ **Important**: This tool is for educational and research purposes only. 
        Do not use real patient data. Always consult healthcare professionals for medical decisions.
        """)
        
        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Medical Image",
                    type="pil",  # deliver a PIL.Image to analyze_medical_image
                    height=400
                )
                
                clinical_question = gr.Textbox(
                    label="Clinical Question",
                    placeholder="e.g., 'Describe the findings in this chest X-ray' or 'What pathological changes do you observe?'",
                    lines=3
                )
                
                patient_history = gr.Textbox(
                    label="Patient History (Optional)",
                    placeholder="e.g., '45-year-old male with chronic cough and shortness of breath'",
                    lines=3
                )
                
                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
            
            # Right column: model output.
            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Medical Analysis",
                    lines=20,
                    max_lines=30
                )
        
        # Example cases
        gr.Markdown("## 📋 Example Cases")
        
        # First example uses a public-domain image URL; the other two expect
        # the user to upload their own image (image slot left as None).
        examples = gr.Examples(
            examples=[
                [
                    "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
                    "Describe the findings in this chest X-ray and identify any abnormalities.",
                    "Adult patient with respiratory symptoms"
                ],
                [
                    None,  # User will upload their own
                    "What pathological changes are visible in this medical image?",
                    ""
                ],
                [
                    None,
                    "Provide a differential diagnosis based on the imaging findings.",
                    "Patient presenting with acute symptoms"
                ]
            ],
            inputs=[image_input, clinical_question, patient_history]
        )
        
        # Event handlers: button click runs the analysis into the output box.
        analyze_btn.click(
            fn=analyze_medical_image,
            inputs=[image_input, clinical_question, patient_history],
            outputs=output
        )
        
        # Footer
        gr.Markdown("""
        ---
        **Data Privacy Notice**: Do not upload real patient data or personally identifiable information. 
        Use only synthetic, anonymized, or publicly available medical images for demonstration purposes.
        
        **Model**: Google MedGemma-4B | **Purpose**: Educational and Research Use Only
        """)
    
    return demo
174
+
175
# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (needed in containers/Spaces)
        server_port=7860,       # the port Hugging Face Spaces expects
        share=True              # NOTE(review): share links are ignored (with a
                                # warning) inside HF Spaces; only affects local runs — confirm intent
    )
requirements_txt.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.40.0
3
+ gradio>=4.0.0
4
+ Pillow>=9.0.0
5
+ accelerate>=0.20.0
6
+ requests>=2.28.0
7
+ numpy>=1.21.0