# app.py - Main Gradio application
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import os

# Model configuration
MODEL_ID = "google/medgemma-4b-it"
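# Note: this checkpoint is gated on Hugging Face; you may need to accept the model's
# terms of use and authenticate (e.g. via an HF_TOKEN secret or `huggingface-cli login`)
# before it can be downloaded.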

# Load model and processor
def load_model():
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # requires the `accelerate` package to be installed
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor

# Load the model and processor once at startup; they are shared across requests
model, processor = load_model()

def analyze_medical_image(image, clinical_question, patient_history=""):
    """
    Analyze medical image with clinical context
    """
    if image is None:
        return "Please upload an image first."
    
    # Prepare the conversation
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert medical AI assistant. Provide detailed analysis while emphasizing that this is for educational purposes only and should not replace professional medical diagnosis."}]
        }
    ]
    
    # Add patient history if provided
    if patient_history.strip():
        messages.append({
            "role": "user", 
            "content": [{"type": "text", "text": f"Patient History: {patient_history}"}]
        })
    
    # Add the main question with image
    user_content = [{"type": "text", "text": clinical_question}]
    if image is not None:
        user_content.append({"type": "image", "image": image})
    
    messages.append({
        "role": "user",
        "content": user_content
    })
    
    try:
        # Process inputs
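        # (the dtype argument in .to() below casts only floating-point tensors such as
        # pixel_values to bfloat16; integer tensors like input_ids are simply moved to the device)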
        inputs = processor.apply_chat_template(
            messages, 
            add_generation_prompt=True, 
            tokenize=True,
            return_dict=True, 
            return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)
        
        input_len = inputs["input_ids"].shape[-1]
        
        # Generate response
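        # Sampling settings (do_sample/temperature/top_p) trade reproducibility for variety;
        # set do_sample=False for deterministic, greedy decoding.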
        with torch.inference_mode():
            generation = model.generate(
                **inputs, 
                max_new_tokens=1000,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )
            generation = generation[0][input_len:]
        
        # Decode response
        response = processor.decode(generation, skip_special_tokens=True)
        
        # Add disclaimer
        disclaimer = "\n\n⚠️ IMPORTANT DISCLAIMER: This analysis is for educational and research purposes only. Always consult qualified healthcare professionals for medical diagnosis and treatment decisions."
        
        return response + disclaimer
        
    except Exception as e:
        return f"Error processing request: {str(e)}"

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="MedGemma Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🏥 MedGemma Medical Image Analysis
        
        **Educational Medical AI Assistant powered by Google's MedGemma-4B**
        
        ⚠️ **Important**: This tool is for educational and research purposes only. 
        Do not use real patient data. Always consult healthcare professionals for medical decisions.
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Medical Image", 
                    type="pil",
                    height=400
                )
                
                clinical_question = gr.Textbox(
                    label="Clinical Question",
                    placeholder="e.g., 'Describe the findings in this chest X-ray' or 'What pathological changes do you observe?'",
                    lines=3
                )
                
                patient_history = gr.Textbox(
                    label="Patient History (Optional)",
                    placeholder="e.g., '45-year-old male with chronic cough and shortness of breath'",
                    lines=3
                )
                
                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
                
            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Medical Analysis",
                    lines=20,
                    max_lines=30
                )
        
        # Example cases
        gr.Markdown("## 📋 Example Cases")
        
        examples = gr.Examples(
            examples=[
                [
                    "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
                    "Describe the findings in this chest X-ray and identify any abnormalities.",
                    "Adult patient with respiratory symptoms"
                ],
                [
                    None,  # User will upload their own
                    "What pathological changes are visible in this medical image?",
                    ""
                ],
                [
                    None,
                    "Provide a differential diagnosis based on the imaging findings.",
                    "Patient presenting with acute symptoms"
                ]
            ],
            inputs=[image_input, clinical_question, patient_history]
        )
        
        # Event handlers
        analyze_btn.click(
            fn=analyze_medical_image,
            inputs=[image_input, clinical_question, patient_history],
            outputs=output
        )
        
        # Footer
        gr.Markdown("""
        ---
        **Data Privacy Notice**: Do not upload real patient data or personally identifiable information. 
        Use only synthetic, anonymized, or publicly available medical images for demonstration purposes.
        
        **Model**: Google MedGemma-4B | **Purpose**: Educational and Research Use Only
        """)
    
    return demo

# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True  # creates a temporary public gradio.live link when running locally
    )