medrax.org / interface.py
oldcai's picture
Upload folder using huggingface_hub
d7a7846 verified
import re
import gradio as gr
from pathlib import Path
import time
import shutil
from typing import AsyncGenerator, List, Optional, Tuple
from gradio import ChatMessage
class ChatInterface:
"""
A chat interface for interacting with a medical AI agent through Gradio.
Handles file uploads, message processing, and chat history management.
Supports both regular image files and DICOM medical imaging files.
"""
def __init__(self, agent, tools_dict):
"""
Initialize the chat interface.
Args:
agent: The medical AI agent to handle requests
tools_dict (dict): Dictionary of available tools for image processing
"""
self.agent = agent
self.tools_dict = tools_dict
self.upload_dir = Path("temp")
self.upload_dir.mkdir(exist_ok=True)
self.current_thread_id = None
# Separate storage for original and display paths
self.original_file_path = None # For LLM (.dcm or other)
self.display_file_path = None # For UI (always viewable format)
def handle_upload(self, file_path: str) -> str:
"""
Handle new file upload and set appropriate paths.
Args:
file_path (str): Path to the uploaded file
Returns:
str: Display path for UI, or None if no file uploaded
"""
if not file_path:
return None
source = Path(file_path)
timestamp = int(time.time())
# Save original file with proper suffix
suffix = source.suffix.lower()
saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
shutil.copy2(file_path, saved_path) # Use file_path directly instead of source
self.original_file_path = str(saved_path)
# Handle DICOM conversion for display only
if suffix == ".dcm":
output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
self.display_file_path = output["image_path"]
else:
self.display_file_path = str(saved_path)
return self.display_file_path
def add_message(
self, message: str, display_image: str, history: List[dict]
) -> Tuple[List[dict], gr.Textbox]:
"""
Add a new message to the chat history.
Args:
message (str): Text message to add
display_image (str): Path to image being displayed
history (List[dict]): Current chat history
Returns:
Tuple[List[dict], gr.Textbox]: Updated history and textbox component
"""
image_path = self.original_file_path or display_image
if image_path is not None:
history.append({"role": "user", "content": {"path": image_path}})
if message is not None:
history.append({"role": "user", "content": message})
return history, gr.Textbox(value=message, interactive=False)
async def process_message(
self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
"""
Process a message and generate responses.
Args:
message (str): User message to process
display_image (Optional[str]): Path to currently displayed image
chat_history (List[ChatMessage]): Current chat history
Yields:
Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and empty string
"""
chat_history = chat_history or []
# Initialize thread if needed
if not self.current_thread_id:
self.current_thread_id = str(time.time())
messages = []
image_path = self.original_file_path or display_image
if image_path is not None:
messages.append({"role": "user", "content": f"path: {image_path}"})
if message is not None:
messages.append({"role": "user", "content": message})
try:
for event in self.agent.workflow.stream(
{"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
):
if isinstance(event, dict):
if "process" in event:
content = event["process"]["messages"][-1].content
if content:
content = re.sub(r"temp/[^\s]*", "", content)
chat_history.append(ChatMessage(role="assistant", content=content))
yield chat_history, self.display_file_path, ""
elif "execute" in event:
for message in event["execute"]["messages"]:
tool_name = message.name
tool_result = eval(message.content)[0]
if tool_result:
metadata = {"title": f"πŸ–ΌοΈ Image from tool: {tool_name}"}
formatted_result = " ".join(
line.strip() for line in str(tool_result).splitlines()
).strip()
metadata["description"] = formatted_result
chat_history.append(
ChatMessage(
role="assistant",
content=formatted_result,
metadata=metadata,
)
)
# For image_visualizer, use display path
if tool_name == "image_visualizer":
self.display_file_path = tool_result["image_path"]
chat_history.append(
ChatMessage(
role="assistant",
# content=gr.Image(value=self.display_file_path),
content={"path": self.display_file_path},
)
)
yield chat_history, self.display_file_path, ""
except Exception as e:
chat_history.append(
ChatMessage(
role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
)
)
yield chat_history, self.display_file_path
def create_demo(agent, tools_dict):
"""
Create a Gradio demo interface for the medical AI agent.
Args:
agent: The medical AI agent to handle requests
tools_dict (dict): Dictionary of available tools for image processing
Returns:
gr.Blocks: Gradio Blocks interface
"""
interface = ChatInterface(agent, tools_dict)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Column():
gr.Markdown(
"""
# πŸ₯ MedRAX
Medical Reasoning Agent for Chest X-ray
"""
)
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
[],
height=800,
container=True,
show_label=True,
elem_classes="chat-box",
type="messages",
label="Agent",
avatar_images=(
None,
"assets/medrax_logo.jpg",
),
)
with gr.Row():
with gr.Column(scale=3):
txt = gr.Textbox(
show_label=False,
placeholder="Ask about the X-ray...",
container=False,
)
with gr.Column(scale=3):
image_display = gr.Image(
label="Image", type="filepath", height=700, container=True
)
with gr.Row():
upload_button = gr.UploadButton(
"πŸ“Ž Upload X-Ray",
file_types=["image"],
)
dicom_upload = gr.UploadButton(
"πŸ“„ Upload DICOM",
file_types=["file"],
)
with gr.Row():
clear_btn = gr.Button("Clear Chat")
new_thread_btn = gr.Button("New Thread")
# Event handlers
def clear_chat():
interface.original_file_path = None
interface.display_file_path = None
return [], None
def new_thread():
interface.current_thread_id = str(time.time())
return [], interface.display_file_path
def handle_file_upload(file):
return interface.handle_upload(file.name)
chat_msg = txt.submit(
interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
)
bot_msg = chat_msg.then(
interface.process_message,
inputs=[txt, image_display, chatbot],
outputs=[chatbot, image_display, txt],
)
bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)
dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
clear_btn.click(clear_chat, outputs=[chatbot, image_display])
new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
return demo