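"""Streamlit chat application for the MedQA Assistant.

Wires a MedCPT text retriever, a two-model figure annotator, and an LLM
client into a MedQAAssistant that answers medical questions over W&B-hosted
document artifacts. Artifact addresses and model names are configurable from
the sidebar.

Launch with: streamlit run <path-to-this-script>
"""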
import streamlit as st
import weave
from dotenv import load_dotenv

from medrag_multi_modal.assistant import (
    FigureAnnotatorFromPageImage,
    LLMClient,
    MedQAAssistant,
)
from medrag_multi_modal.retrieval import MedCPTRetriever

# Load environment variables
load_dotenv()
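# NOTE: the .env file is expected to supply API keys for the LLM providers
# configured below (for the defaults: Google for Gemini, Mistral for Pixtral,
# OpenAI for GPT-4o); the exact variable names depend on LLMClient.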

# Sidebar for configuration settings
st.sidebar.title("Configuration Settings")
project_name = st.sidebar.text_input(
    "Project Name",
    "ml-colabs/medrag-multi-modal",
)
chunk_dataset_name = st.sidebar.text_input(
    "Text Chunk WandB Dataset Name",
    "grays-anatomy-chunks:v0",
)
index_artifact_address = st.sidebar.text_input(
    "WandB Index Artifact Address",
    "ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
)
image_artifact_address = st.sidebar.text_input(
    "WandB Image Artifact Address",
    "ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)
llm_model_name = st.sidebar.text_input(
    "LLM Client Model Name",
    "gemini-1.5-flash",
)
figure_extraction_model_name = st.sidebar.text_input(
    "Figure Extraction Model Name",
    "pixtral-12b-2409",
)
structured_output_model_name = st.sidebar.text_input(
    "Structured Output Model Name",
    "gpt-4o",
)

# Initialize Weave to trace LLM calls under the configured W&B project
weave.init(project_name=project_name)

# Initialize clients and assistants
llm_client = LLMClient(model_name=llm_model_name)
retriever = MedCPTRetriever.from_wandb_artifact(
    chunk_dataset_name=chunk_dataset_name,
    index_artifact_address=index_artifact_address,
)
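# The figure annotator pairs two models: one extracts figure descriptions
# from raw page images, the other converts that output into structured form.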
figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
    structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
    image_artifact_address=image_artifact_address,
)
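# MedQAAssistant combines the retriever (text chunks), the figure annotator
# (figures from matching page images), and the LLM client to answer queries.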
medqa_assistant = MedQAAssistant(
    llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
)

# Streamlit app layout
st.title("MedQA Assistant App")

# Initialize chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Display chat messages from history on app rerun
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle a new user query and generate the assistant's response
if query := st.chat_input("What medical question can I assist you with today?"):
    # Add user message to chat history
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    # Process query and get response
    response = medqa_assistant.predict(query=query)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)