import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoTokenizer,
    AutoProcessor,
    pipeline,
)
from PIL import Image
import os
import spaces
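
# `spaces` provides the @spaces.GPU decorator used below, which requests GPU time
# for model loading and inference when this app runs on Hugging Face Spaces (ZeroGPU).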

try:
    from transformers import BitsAndBytesConfig
    QUANTIZATION_AVAILABLE = True
except ImportError:
    QUANTIZATION_AVAILABLE = False
    print("⚠️ bitsandbytes not available. Quantization will be disabled.")
MODEL_4B = "google/medgemma-4b-it"
MODEL_27B = "google/medgemma-27b-text-it"


class MedGemmaApp:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_processor = None
        self.current_pipe = None
        self.model_type = None
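
    # get_model_kwargs builds the keyword arguments passed to from_pretrained():
    # bfloat16 weights, automatic device placement, and (optionally) 4-bit
    # quantization via BitsAndBytesConfig when bitsandbytes is installed.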
    def get_model_kwargs(self, use_quantization=True):
        """Get model configuration arguments"""
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
        }

        if use_quantization and QUANTIZATION_AVAILABLE:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        elif use_quantization and not QUANTIZATION_AVAILABLE:
            print("⚠️ Quantization requested but bitsandbytes not available. Loading without quantization.")

        return model_kwargs
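
    # Models are loaded on demand, and the previously loaded checkpoint is released
    # first, so only one model occupies GPU memory at a time.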
    @spaces.GPU
    def load_model(self, model_choice, use_quantization=True):
        """Load the selected model"""
        try:
            model_id = MODEL_4B if model_choice == "4B (Multimodal)" else MODEL_27B
            model_kwargs = self.get_model_kwargs(use_quantization)

            # Release the previous model/pipeline before loading a new one. The
            # attributes are reset to None (rather than deleted) so later checks such
            # as `if self.current_processor:` cannot raise AttributeError after a swap.
            if self.current_model is not None:
                self.current_model = None
                self.current_tokenizer = None
                self.current_processor = None
                self.current_pipe = None
                torch.cuda.empty_cache()

            if model_choice == "4B (Multimodal)":
                self.current_model = AutoModelForImageTextToText.from_pretrained(
                    model_id, **model_kwargs
                )
                self.current_processor = AutoProcessor.from_pretrained(model_id)
                self.model_type = "multimodal"

                self.current_pipe = pipeline(
                    "image-text-to-text",
                    model=self.current_model,
                    processor=self.current_processor,
                )
                self.current_pipe.model.generation_config.do_sample = False
            else:
                self.current_model = AutoModelForCausalLM.from_pretrained(
                    model_id, **model_kwargs
                )
                self.current_tokenizer = AutoTokenizer.from_pretrained(model_id)
                self.model_type = "text"

                self.current_pipe = pipeline(
                    "text-generation",
                    model=self.current_model,
                    tokenizer=self.current_tokenizer,
                )
                self.current_pipe.model.generation_config.do_sample = False

            return f"✅ Successfully loaded {model_choice} model!"

        except Exception as e:
            return f"❌ Error loading model: {str(e)}"
    @spaces.GPU
    def chat_text_only(self, message, history, system_instruction="You are a helpful medical assistant."):
        """Handle text-only conversations"""
        if self.current_model is None or self.model_type != "text":
            return "Please load the 27B (Text Only) model first!"

        try:
            messages = [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": message}
            ]

            # Insert prior turns before the current user message so the order becomes:
            # system, history..., current question.
            for human, assistant in history:
                messages.insert(-1, {"role": "user", "content": human})
                messages.insert(-1, {"role": "assistant", "content": assistant})

            output = self.current_pipe(messages, max_new_tokens=500)
            response = output[0]["generated_text"][-1]["content"]

            return response

        except Exception as e:
            return f"Error generating response: {str(e)}"
    @spaces.GPU
    def chat_with_image(self, message, image, system_instruction="You are an expert radiologist."):
        """Handle image + text conversations"""
        if self.current_model is None or self.model_type != "multimodal":
            return "Please load the 4B (Multimodal) model first!"

        if image is None:
            return "Please upload an image to analyze."

        try:
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_instruction}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": message},
                        {"type": "image", "image": image}
                    ]
                }
            ]

            output = self.current_pipe(text=messages, max_new_tokens=300)
            response = output[0]["generated_text"][-1]["content"]

            return response

        except Exception as e:
            return f"Error analyzing image: {str(e)}"
app = MedGemmaApp()


with gr.Blocks(title="MedGemma Medical AI Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 MedGemma Medical AI Assistant

    Welcome to MedGemma, Google's medical AI assistant! Choose between:
    - **4B Multimodal**: Analyze medical images (X-rays, scans) with text
    - **27B Text-Only**: Advanced medical text conversations

    > **Note**: This is for educational and research purposes only. Always consult healthcare professionals for medical advice.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                choices=["4B (Multimodal)", "27B (Text Only)"],
                value="4B (Multimodal)",
                label="Select Model",
                info="4B supports images, 27B is text-only but more powerful"
            )

            use_quantization = gr.Checkbox(
                value=QUANTIZATION_AVAILABLE,
                label="Use 4-bit Quantization" + ("" if QUANTIZATION_AVAILABLE else " (Unavailable)"),
                info="Reduces memory usage" + ("" if QUANTIZATION_AVAILABLE else " - bitsandbytes not installed"),
                interactive=QUANTIZATION_AVAILABLE
            )

            load_btn = gr.Button("🚀 Load Model", variant="primary")
            model_status = gr.Textbox(label="Model Status", interactive=False)

    with gr.Tabs():

        with gr.Tab("💬 Text Chat", id="text_chat"):
            gr.Markdown("### Medical Text Consultation")

            with gr.Row():
                with gr.Column(scale=3):
                    text_system = gr.Textbox(
                        value="You are a helpful medical assistant.",
                        label="System Instruction",
                        placeholder="Set the AI's role and behavior..."
                    )

                    chatbot_text = gr.Chatbot(
                        height=400,
                        placeholder="Start a medical conversation...",
                        label="Medical Assistant"
                    )

                    with gr.Row():
                        text_input = gr.Textbox(
                            placeholder="Ask a medical question...",
                            label="Your Question",
                            scale=4
                        )
                        text_submit = gr.Button("Send", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### 💡 Example Questions:
                    - How do you differentiate bacterial from viral pneumonia?
                    - What are the symptoms of diabetes?
                    - Explain the mechanism of action of ACE inhibitors
                    - What are the contraindications for MRI?
                    """)

        with gr.Tab("🖼️ Image Analysis", id="image_analysis"):
            gr.Markdown("### Medical Image Analysis")

            with gr.Row():
                with gr.Column(scale=2):
                    image_input = gr.Image(
                        type="pil",
                        label="Upload Medical Image",
                        height=300
                    )

                    image_system = gr.Textbox(
                        value="You are an expert radiologist.",
                        label="System Instruction"
                    )

                    image_text_input = gr.Textbox(
                        value="Describe this X-ray",
                        label="Question about the image",
                        placeholder="What would you like to know about this image?"
                    )

                    image_submit = gr.Button("🔍 Analyze Image", variant="primary")

                with gr.Column(scale=2):
                    image_output = gr.Textbox(
                        label="Analysis Result",
                        lines=15,
                        placeholder="Upload an image and click 'Analyze Image' to see the AI's analysis..."
                    )

    load_btn.click(
        fn=app.load_model,
        inputs=[model_choice, use_quantization],
        outputs=[model_status]
    )
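
    # respond_text returns (updated history, "") so the Chatbot refreshes with the new
    # turn and the input textbox is cleared after each submission.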
    def respond_text(message, history, system_instruction):
        if message.strip() == "":
            return history, ""

        response = app.chat_text_only(message, history, system_instruction)
        history.append((message, response))
        return history, ""

    text_submit.click(
        fn=respond_text,
        inputs=[text_input, chatbot_text, text_system],
        outputs=[chatbot_text, text_input]
    )

    text_input.submit(
        fn=respond_text,
        inputs=[text_input, chatbot_text, text_system],
        outputs=[chatbot_text, text_input]
    )

    image_submit.click(
        fn=app.chat_with_image,
        inputs=[image_text_input, image_input, image_system],
        outputs=[image_output]
    )

    gr.Markdown("""
    ---
    ### 📚 About MedGemma
    MedGemma is a collection of Gemma variants trained for medical applications.
    Learn more at the [HAI-DEF developer site](https://developers.google.com/health-ai-developer-foundations/medgemma).

    **Disclaimer**: This tool is for educational and research purposes only.
    Always consult qualified healthcare professionals for medical advice.
    """)


if __name__ == "__main__":
    demo.launch()