# lumenex / app.py
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoTokenizer,
    AutoProcessor,
    pipeline
)
from PIL import Image
import os
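# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU Spaces
# to allocate a GPU for the duration of each decorated call.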
import spaces
# Try to import bitsandbytes for quantization (optional)
try:
    from transformers import BitsAndBytesConfig
    QUANTIZATION_AVAILABLE = True
except ImportError:
    QUANTIZATION_AVAILABLE = False
    print("⚠️ bitsandbytes not available. Quantization will be disabled.")
# Configuration
MODEL_4B = "google/medgemma-4b-it"
MODEL_27B = "google/medgemma-27b-text-it"
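# MODEL_4B is the multimodal (image + text) instruction-tuned variant;
# MODEL_27B is the larger, text-only variant.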
class MedGemmaApp:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_processor = None
        self.current_pipe = None
        self.model_type = None

    def get_model_kwargs(self, use_quantization=True):
        """Get model configuration arguments"""
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
        }
        # Only add quantization if available and requested
        if use_quantization and QUANTIZATION_AVAILABLE:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        elif use_quantization and not QUANTIZATION_AVAILABLE:
            print("⚠️ Quantization requested but bitsandbytes not available. Loading without quantization.")
        return model_kwargs

    @spaces.GPU
    def load_model(self, model_choice, use_quantization=True):
        """Load the selected model"""
        try:
            model_id = MODEL_4B if model_choice == "4B (Multimodal)" else MODEL_27B
            model_kwargs = self.get_model_kwargs(use_quantization)

            # Release any previously loaded model before loading a new one
            # (assigning None instead of del keeps the attributes valid on reload)
            if self.current_model is not None:
                self.current_model = None
                self.current_tokenizer = None
                self.current_processor = None
                self.current_pipe = None
                torch.cuda.empty_cache()

            if model_choice == "4B (Multimodal)":
                # Load multimodal model
                self.current_model = AutoModelForImageTextToText.from_pretrained(
                    model_id, **model_kwargs
                )
                self.current_processor = AutoProcessor.from_pretrained(model_id)
                self.model_type = "multimodal"

                # Create pipeline for easier inference
                self.current_pipe = pipeline(
                    "image-text-to-text",
                    model=self.current_model,
                    processor=self.current_processor,
                )
                self.current_pipe.model.generation_config.do_sample = False
            else:
                # Load text-only model
                self.current_model = AutoModelForCausalLM.from_pretrained(
                    model_id, **model_kwargs
                )
                self.current_tokenizer = AutoTokenizer.from_pretrained(model_id)
                self.model_type = "text"

                # Create pipeline for easier inference
                self.current_pipe = pipeline(
                    "text-generation",
                    model=self.current_model,
                    tokenizer=self.current_tokenizer,
                )
                self.current_pipe.model.generation_config.do_sample = False

            return f"✅ Successfully loaded {model_choice} model!"
        except Exception as e:
            return f"❌ Error loading model: {str(e)}"

    @spaces.GPU
    def chat_text_only(self, message, history, system_instruction="You are a helpful medical assistant."):
        """Handle text-only conversations"""
        if self.current_model is None or self.model_type != "text":
            return "Please load the 27B (Text Only) model first!"
        try:
            messages = [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": message}
            ]
            # Add conversation history
            for human, assistant in history:
                messages.insert(-1, {"role": "user", "content": human})
                messages.insert(-1, {"role": "assistant", "content": assistant})

            output = self.current_pipe(messages, max_new_tokens=500)
            response = output[0]["generated_text"][-1]["content"]
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"

    @spaces.GPU
    def chat_with_image(self, message, image, system_instruction="You are an expert radiologist."):
        """Handle image + text conversations"""
        if self.current_model is None or self.model_type != "multimodal":
            return "Please load the 4B (Multimodal) model first!"
        if image is None:
            return "Please upload an image to analyze."
        try:
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_instruction}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": message},
                        {"type": "image", "image": image}
                    ]
                }
            ]
            output = self.current_pipe(text=messages, max_new_tokens=300)
            response = output[0]["generated_text"][-1]["content"]
            return response
        except Exception as e:
            return f"Error analyzing image: {str(e)}"
# Initialize the app
app = MedGemmaApp()

# Create Gradio interface
with gr.Blocks(title="MedGemma Medical AI Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 MedGemma Medical AI Assistant

    Welcome to MedGemma, Google's medical AI assistant! Choose between:
    - **4B Multimodal**: Analyze medical images (X-rays, scans) with text
    - **27B Text-Only**: Advanced medical text conversations

    > **Note**: This is for educational and research purposes only. Always consult healthcare professionals for medical advice.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                choices=["4B (Multimodal)", "27B (Text Only)"],
                value="4B (Multimodal)",
                label="Select Model",
                info="4B supports images, 27B is text-only but more powerful"
            )
            use_quantization = gr.Checkbox(
                value=QUANTIZATION_AVAILABLE,
                label="Use 4-bit Quantization" + ("" if QUANTIZATION_AVAILABLE else " (Unavailable)"),
                info="Reduces memory usage" + ("" if QUANTIZATION_AVAILABLE else " - bitsandbytes not installed"),
                interactive=QUANTIZATION_AVAILABLE
            )
            load_btn = gr.Button("🚀 Load Model", variant="primary")
            model_status = gr.Textbox(label="Model Status", interactive=False)

    with gr.Tabs():
        # Text-only chat tab
        with gr.Tab("💬 Text Chat", id="text_chat"):
            gr.Markdown("### Medical Text Consultation")
            with gr.Row():
                with gr.Column(scale=3):
                    text_system = gr.Textbox(
                        value="You are a helpful medical assistant.",
                        label="System Instruction",
                        placeholder="Set the AI's role and behavior..."
                    )
                    chatbot_text = gr.Chatbot(
                        height=400,
                        placeholder="Start a medical conversation...",
                        label="Medical Assistant"
                    )
                    with gr.Row():
                        text_input = gr.Textbox(
                            placeholder="Ask a medical question...",
                            label="Your Question",
                            scale=4
                        )
                        text_submit = gr.Button("Send", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### 💡 Example Questions:
                    - How do you differentiate bacterial from viral pneumonia?
                    - What are the symptoms of diabetes?
                    - Explain the mechanism of action of ACE inhibitors
                    - What are the contraindications for MRI?
                    """)

        # Image analysis tab
        with gr.Tab("🖼️ Image Analysis", id="image_analysis"):
            gr.Markdown("### Medical Image Analysis")
            with gr.Row():
                with gr.Column(scale=2):
                    image_input = gr.Image(
                        type="pil",
                        label="Upload Medical Image",
                        height=300
                    )
                    image_system = gr.Textbox(
                        value="You are an expert radiologist.",
                        label="System Instruction"
                    )
                    image_text_input = gr.Textbox(
                        value="Describe this X-ray",
                        label="Question about the image",
                        placeholder="What would you like to know about this image?"
                    )
                    image_submit = gr.Button("🔍 Analyze Image", variant="primary")
                with gr.Column(scale=2):
                    image_output = gr.Textbox(
                        label="Analysis Result",
                        lines=15,
                        placeholder="Upload an image and click 'Analyze Image' to see the AI's analysis..."
                    )

    # Event handlers
    load_btn.click(
        fn=app.load_model,
        inputs=[model_choice, use_quantization],
        outputs=[model_status]
    )

    def respond_text(message, history, system_instruction):
        if message.strip() == "":
            return history, ""
        response = app.chat_text_only(message, history, system_instruction)
        history.append((message, response))
        return history, ""

    text_submit.click(
        fn=respond_text,
        inputs=[text_input, chatbot_text, text_system],
        outputs=[chatbot_text, text_input]
    )
    text_input.submit(
        fn=respond_text,
        inputs=[text_input, chatbot_text, text_system],
        outputs=[chatbot_text, text_input]
    )
    image_submit.click(
        fn=app.chat_with_image,
        inputs=[image_text_input, image_input, image_system],
        outputs=[image_output]
    )

    # About / disclaimer section
    gr.Markdown("""
    ---
    ### 📚 About MedGemma

    MedGemma is a collection of Gemma variants trained for medical applications.
    Learn more at the [HAI-DEF developer site](https://developers.google.com/health-ai-developer-foundations/medgemma).

    **Disclaimer**: This tool is for educational and research purposes only.
    Always consult qualified healthcare professionals for medical advice.
    """)

if __name__ == "__main__":
    demo.launch()