import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoTokenizer,
    AutoProcessor,
    pipeline,
)
from PIL import Image
import os
import spaces

# Try to import bitsandbytes for quantization (optional)
try:
    from transformers import BitsAndBytesConfig
    QUANTIZATION_AVAILABLE = True
except ImportError:
    QUANTIZATION_AVAILABLE = False
    print("⚠️ bitsandbytes not available. Quantization will be disabled.")

# Configuration
MODEL_4B = "google/medgemma-4b-it"
MODEL_27B = "google/medgemma-27b-text-it"


class MedGemmaApp:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_processor = None
        self.current_pipe = None
        self.model_type = None

    def get_model_kwargs(self, use_quantization=True):
        """Get model configuration arguments."""
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
        }
        # Only add quantization if requested and bitsandbytes is available
        if use_quantization and QUANTIZATION_AVAILABLE:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        elif use_quantization and not QUANTIZATION_AVAILABLE:
            print("⚠️ Quantization requested but bitsandbytes not available. Loading without quantization.")
        return model_kwargs

    @spaces.GPU
    def load_model(self, model_choice, use_quantization=True):
        """Load the selected model."""
        try:
            model_id = MODEL_4B if model_choice == "4B (Multimodal)" else MODEL_27B
            model_kwargs = self.get_model_kwargs(use_quantization)

            # Release any previously loaded model before loading a new one.
            # Reassigning to None (instead of `del`) keeps the attributes defined,
            # so a later reload cannot hit an AttributeError.
            if self.current_model is not None:
                self.current_model = None
                self.current_tokenizer = None
                self.current_processor = None
                self.current_pipe = None
                torch.cuda.empty_cache()

            if model_choice == "4B (Multimodal)":
                # Load the multimodal model
                self.current_model = AutoModelForImageTextToText.from_pretrained(
                    model_id, **model_kwargs
                )
                self.current_processor = AutoProcessor.from_pretrained(model_id)
                self.model_type = "multimodal"

                # Create a pipeline for easier inference
                self.current_pipe = pipeline(
                    "image-text-to-text",
                    model=self.current_model,
                    processor=self.current_processor,
                )
                self.current_pipe.model.generation_config.do_sample = False
            else:
                # Load the text-only model
                self.current_model = AutoModelForCausalLM.from_pretrained(
                    model_id, **model_kwargs
                )
                self.current_tokenizer = AutoTokenizer.from_pretrained(model_id)
                self.model_type = "text"

                # Create a pipeline for easier inference
                self.current_pipe = pipeline(
                    "text-generation",
                    model=self.current_model,
                    tokenizer=self.current_tokenizer,
                )
                self.current_pipe.model.generation_config.do_sample = False

            return f"✅ Successfully loaded {model_choice} model!"

        except Exception as e:
            return f"❌ Error loading model: {str(e)}"

    @spaces.GPU
    def chat_text_only(self, message, history, system_instruction="You are a helpful medical assistant."):
        """Handle text-only conversations."""
        if self.current_model is None or self.model_type != "text":
            return "Please load the 27B (Text Only) model first!"
        try:
            # Build the conversation in chronological order: system prompt,
            # prior turns from the chat history, then the new user message
            messages = [{"role": "system", "content": system_instruction}]
            for human, assistant in history:
                messages.append({"role": "user", "content": human})
                messages.append({"role": "assistant", "content": assistant})
            messages.append({"role": "user", "content": message})

            output = self.current_pipe(messages, max_new_tokens=500)
            response = output[0]["generated_text"][-1]["content"]
            return response

        except Exception as e:
            return f"Error generating response: {str(e)}"

    @spaces.GPU
    def chat_with_image(self, message, image, system_instruction="You are an expert radiologist."):
        """Handle image + text conversations."""
        if self.current_model is None or self.model_type != "multimodal":
            return "Please load the 4B (Multimodal) model first!"

        if image is None:
            return "Please upload an image to analyze."

        try:
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_instruction}],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": message},
                        {"type": "image", "image": image},
                    ],
                },
            ]

            output = self.current_pipe(text=messages, max_new_tokens=300)
            response = output[0]["generated_text"][-1]["content"]
            return response

        except Exception as e:
            return f"Error analyzing image: {str(e)}"


# Initialize the app
app = MedGemmaApp()

# Create the Gradio interface
with gr.Blocks(title="MedGemma Medical AI Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 MedGemma Medical AI Assistant

    Welcome to MedGemma, Google's medical AI assistant! Choose between:
    - **4B Multimodal**: Analyze medical images (X-rays, scans) with text
    - **27B Text-Only**: Advanced medical text conversations

    > **Note**: This is for educational and research purposes only. Always consult healthcare professionals for medical advice.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                choices=["4B (Multimodal)", "27B (Text Only)"],
                value="4B (Multimodal)",
                label="Select Model",
                info="4B supports images, 27B is text-only but more powerful",
            )
            use_quantization = gr.Checkbox(
                value=QUANTIZATION_AVAILABLE,
                label="Use 4-bit Quantization" + ("" if QUANTIZATION_AVAILABLE else " (Unavailable)"),
                info="Reduces memory usage" + ("" if QUANTIZATION_AVAILABLE else " - bitsandbytes not installed"),
                interactive=QUANTIZATION_AVAILABLE,
            )
            load_btn = gr.Button("🚀 Load Model", variant="primary")
            model_status = gr.Textbox(label="Model Status", interactive=False)

    with gr.Tabs():
        # Text-only chat tab
        with gr.Tab("💬 Text Chat", id="text_chat"):
            gr.Markdown("### Medical Text Consultation")
            with gr.Row():
                with gr.Column(scale=3):
                    text_system = gr.Textbox(
                        value="You are a helpful medical assistant.",
                        label="System Instruction",
                        placeholder="Set the AI's role and behavior...",
                    )
                    chatbot_text = gr.Chatbot(
                        height=400,
                        placeholder="Start a medical conversation...",
                        label="Medical Assistant",
                    )
                    with gr.Row():
                        text_input = gr.Textbox(
                            placeholder="Ask a medical question...",
                            label="Your Question",
                            scale=4,
                        )
                        text_submit = gr.Button("Send", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### 💡 Example Questions:
                    - How do you differentiate bacterial from viral pneumonia?
                    - What are the symptoms of diabetes?
                    - Explain the mechanism of action of ACE inhibitors
                    - What are the contraindications for MRI?
""") # Image analysis tab with gr.Tab("🖼️ Image Analysis", id="image_analysis"): gr.Markdown("### Medical Image Analysis") with gr.Row(): with gr.Column(scale=2): image_input = gr.Image( type="pil", label="Upload Medical Image", height=300 ) image_system = gr.Textbox( value="You are an expert radiologist.", label="System Instruction" ) image_text_input = gr.Textbox( value="Describe this X-ray", label="Question about the image", placeholder="What would you like to know about this image?" ) image_submit = gr.Button("🔍 Analyze Image", variant="primary") with gr.Column(scale=2): image_output = gr.Textbox( label="Analysis Result", lines=15, placeholder="Upload an image and click 'Analyze Image' to see the AI's analysis..." ) # Event handlers load_btn.click( fn=app.load_model, inputs=[model_choice, use_quantization], outputs=[model_status] ) def respond_text(message, history, system_instruction): if message.strip() == "": return history, "" response = app.chat_text_only(message, history, system_instruction) history.append((message, response)) return history, "" text_submit.click( fn=respond_text, inputs=[text_input, chatbot_text, text_system], outputs=[chatbot_text, text_input] ) text_input.submit( fn=respond_text, inputs=[text_input, chatbot_text, text_system], outputs=[chatbot_text, text_input] ) image_submit.click( fn=app.chat_with_image, inputs=[image_text_input, image_input, image_system], outputs=[image_output] ) # Example image loading gr.Markdown(""" --- ### 📚 About MedGemma MedGemma is a collection of Gemma variants trained for medical applications. Learn more at the [HAI-DEF developer site](https://developers.google.com/health-ai-developer-foundations/medgemma). **Disclaimer**: This tool is for educational and research purposes only. Always consult qualified healthcare professionals for medical advice. """) if __name__ == "__main__": demo.launch()