Spaces:
Running
Running
File size: 3,509 Bytes
fadf40f 170d9a9 0dae114 fadf40f 0dae114 170d9a9 0dae114 fadf40f 0dae114 fadf40f 170d9a9 fadf40f 170d9a9 fadf40f 170d9a9 fadf40f 170d9a9 fadf40f 170d9a9 fadf40f 170d9a9 fadf40f 170d9a9 076d575 170d9a9 fadf40f 170d9a9 716efd2 170d9a9 716efd2 170d9a9 716efd2 170d9a9 fadf40f 170d9a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import streamlit as st
from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
from medrag_multi_modal.retrieval.text_retrieval import (
BM25sRetriever,
ContrieverRetriever,
MedCPTRetriever,
NVEmbed2Retriever,
)
# Model names offered in the "LLM Model" sidebar selector, grouped by the
# family prefix of the model name. The combined list keeps its original order.
_GEMINI_MODELS = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"]
_GPT_MODELS = ["gpt-4o", "gpt-4o-mini"]
ALL_AVAILABLE_MODELS = [*_GEMINI_MODELS, *_GPT_MODELS]
# Configuration sidebar. Widgets are created in display order (top to bottom);
# each call yields the widget's current value for this script run.
with st.sidebar:
    st.title("Configuration Settings")

    # NOTE(review): project_name is not read again in the visible part of
    # this script -- confirm whether it is still needed.
    project_name = st.text_input(
        "Project Name",
        value="ml-colabs/medrag-multi-modal",
        placeholder="wandb project name",
        help="format: wandb_username/wandb_project_name",
    )

    chunk_dataset_id = st.selectbox(
        "Chunk Dataset ID",
        options=["ashwiniai/medrag-text-corpus-chunks"],
    )

    # Holds the selected model *name* (a string from ALL_AVAILABLE_MODELS).
    llm_model = st.selectbox("LLM Model", options=ALL_AVAILABLE_MODELS)

    top_k_chunks_for_query = st.slider(
        "Top K Chunks for Query",
        min_value=1,
        max_value=20,
        value=5,
    )
    top_k_chunks_for_options = st.slider(
        "Top K Chunks for Options",
        min_value=1,
        max_value=20,
        value=3,
    )

    # NOTE(review): rely_only_on_context is not forwarded anywhere in the
    # visible part of this script, so toggling it currently has no effect.
    rely_only_on_context = st.checkbox("Rely Only on Context", value=False)

    # The leading empty option means "no retriever chosen yet".
    retriever_type = st.selectbox(
        "Retriever Type",
        options=["", "BM25S", "Contriever", "MedCPT", "NV-Embed-v2"],
    )
# Greeting rendered as the first assistant message. Defined at module level so
# the text can stay flush-left regardless of where it is used.
_GREETING = """
Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.
**Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
Please consult a medical professional for any medical advice.
In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
"""


@st.cache_resource
def _load_retriever(retriever_type, chunk_dataset_id):
    """Build the retriever selected in the sidebar, loading its index once.

    Streamlit re-executes this script from the top on every interaction,
    including every submitted chat message. Without caching, the
    ``from_index`` call below would therefore be repeated for each question.
    ``st.cache_resource`` keeps one retriever per distinct argument pair.

    Args:
        retriever_type: The value of the "Retriever Type" selector.
        chunk_dataset_id: Chunk dataset identifier, forwarded as
            ``chunk_dataset`` to the retrievers that were given it originally
            (all except BM25S).

    Returns:
        The loaded retriever, or ``None`` when ``retriever_type`` matches none
        of the known names (same fallback as before).
    """
    # NOTE(review): the cached object is shared between sessions; this assumes
    # retrieval does not mutate the retriever -- TODO confirm.
    if retriever_type == "BM25S":
        return BM25sRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
        )
    if retriever_type == "Contriever":
        return ContrieverRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
            chunk_dataset=chunk_dataset_id,
        )
    if retriever_type == "MedCPT":
        return MedCPTRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
            chunk_dataset=chunk_dataset_id,
        )
    if retriever_type == "NV-Embed-v2":
        return NVEmbed2Retriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
            chunk_dataset=chunk_dataset_id,
        )
    return None


# The chat UI is only built once a retriever has been picked; the empty first
# option of the selector means "nothing selected yet".
if retriever_type != "":
    # Use a separate name for the client instead of rebinding llm_model, which
    # holds the selected model-name string.
    llm_client = LLMClient(model_name=llm_model)
    retriever = _load_retriever(retriever_type, chunk_dataset_id)

    # NOTE(review): the sidebar's rely_only_on_context value is not forwarded
    # here; confirm whether MedQAAssistant accepts it before wiring it in.
    medqa_assistant = MedQAAssistant(
        llm_client=llm_client,
        retriever=retriever,
        top_k_chunks_for_query=top_k_chunks_for_query,
        top_k_chunks_for_options=top_k_chunks_for_options,
    )

    with st.chat_message("assistant"):
        st.markdown(_GREETING, unsafe_allow_html=True)

    query = st.chat_input("Enter your question here")
    if query:
        # Echo the question, then render the assistant's answer below it.
        with st.chat_message("user"):
            st.markdown(query)
        response = medqa_assistant.predict(query=query)
        with st.chat_message("assistant"):
            st.markdown(response.response)
|