File size: 3,509 Bytes
fadf40f
 
170d9a9
 
 
 
 
 
0dae114
fadf40f
0dae114
170d9a9
 
 
 
 
 
0dae114
fadf40f
 
 
0dae114
 
 
 
fadf40f
170d9a9
 
 
fadf40f
170d9a9
 
 
fadf40f
170d9a9
 
 
 
 
fadf40f
170d9a9
 
 
 
 
fadf40f
170d9a9
 
 
fadf40f
170d9a9
 
 
 
 
 
 
 
 
fadf40f
 
170d9a9
076d575
170d9a9
fadf40f
170d9a9
 
 
 
 
 
 
 
 
716efd2
170d9a9
 
 
 
716efd2
170d9a9
 
 
 
716efd2
170d9a9
 
 
 
 
 
 
 
fadf40f
 
170d9a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st

from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
from medrag_multi_modal.retrieval.text_retrieval import (
    BM25sRetriever,
    ContrieverRetriever,
    MedCPTRetriever,
    NVEmbed2Retriever,
)

# Define constants
# Chat models selectable in the sidebar; the first entry is the default.
ALL_AVAILABLE_MODELS = [
    f"gemini-1.5-{variant}-latest" for variant in ("flash", "pro")
] + [
    "gpt-4o",
    "gpt-4o-mini",
]

# Sidebar for configuration settings
st.sidebar.title("Configuration Settings")

# W&B project identifier in "entity/project" form.
# NOTE(review): collected here but not referenced later in this script —
# confirm whether it is consumed elsewhere (e.g. by the assistant/LLM client).
project_name = st.sidebar.text_input(
    "Project Name",
    value="ml-colabs/medrag-multi-modal",
    placeholder="wandb project name",
    help="format: wandb_username/wandb_project_name",
)

# Hugging Face dataset id holding the pre-chunked text corpus.
chunk_dataset_id = st.sidebar.selectbox(
    "Chunk Dataset ID",
    options=["ashwiniai/medrag-text-corpus-chunks"],
)

# Chat model name used to answer questions.
llm_model = st.sidebar.selectbox(
    "LLM Model",
    options=ALL_AVAILABLE_MODELS,
)

# Number of retrieved chunks used for the query and for the answer options.
top_k_chunks_for_query = st.sidebar.slider(
    "Top K Chunks for Query", min_value=1, max_value=20, value=5
)
top_k_chunks_for_options = st.sidebar.slider(
    "Top K Chunks for Options", min_value=1, max_value=20, value=3
)

# NOTE(review): this flag is collected but never used further down in this
# script — confirm whether it should be forwarded to MedQAAssistant.
rely_only_on_context = st.sidebar.checkbox("Rely Only on Context", value=False)

# The leading empty string means "no retriever selected yet"; the rest of the
# script only runs once a real retriever is picked.
retriever_type = st.sidebar.selectbox(
    "Retriever Type",
    options=["", "BM25S", "Contriever", "MedCPT", "NV-Embed-v2"],
)

# Only wire up the assistant once the user has picked a real retriever
# (the empty string is the "nothing selected" sentinel option).
if retriever_type != "":

    # Keep the model name and the client in separate variables instead of
    # rebinding ``llm_model``.
    llm_client = LLMClient(model_name=llm_model)

    # Dispatch table: sidebar choice -> (retriever class, prebuilt index repo).
    retriever_registry = {
        "BM25S": (BM25sRetriever, "ashwiniai/medrag-text-corpus-chunks-bm25s"),
        "Contriever": (
            ContrieverRetriever,
            "ashwiniai/medrag-text-corpus-chunks-contriever",
        ),
        "MedCPT": (MedCPTRetriever, "ashwiniai/medrag-text-corpus-chunks-medcpt"),
        "NV-Embed-v2": (
            NVEmbed2Retriever,
            "ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
        ),
    }

    retriever = None
    selection = retriever_registry.get(retriever_type)
    if selection is not None:
        retriever_cls, index_repo_id = selection
        if retriever_cls is BM25sRetriever:
            # BM25S loads from its index alone, without the chunk dataset.
            retriever = retriever_cls.from_index(index_repo_id=index_repo_id)
        else:
            # The other retrievers are loaded with the chunk dataset as well.
            retriever = retriever_cls.from_index(
                index_repo_id=index_repo_id,
                chunk_dataset=chunk_dataset_id,
            )

    # NOTE(review): ``rely_only_on_context`` from the sidebar is never
    # forwarded here — confirm whether MedQAAssistant should receive it.
    medqa_assistant = MedQAAssistant(
        llm_client=llm_client,
        retriever=retriever,
        top_k_chunks_for_query=top_k_chunks_for_query,
        top_k_chunks_for_options=top_k_chunks_for_options,
    )

    # Static greeting rendered as the assistant's first chat message.
    with st.chat_message("assistant"):
        st.markdown(
            """
Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.

**Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
Please consult a medical professional for any medical advice.

In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
            """,
            unsafe_allow_html=True,
        )

    # Echo the user's question, run the assistant, and render its answer.
    user_query = st.chat_input("Enter your question here")
    if user_query:
        with st.chat_message("user"):
            st.markdown(user_query)
        prediction = medqa_assistant.predict(query=user_query)
        with st.chat_message("assistant"):
            st.markdown(prediction.response)