import os

import streamlit as st
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from together import Together
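
# --- Configuration -----------------------------------------------------------
# The Together API key is read from the environment variable "pilotikval",
# this deployment's chosen name for the secret (not a Together convention).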
TOGETHER_API_KEY = os.environ.get("pilotikval")
if not TOGETHER_API_KEY:
    st.error("Missing pilotikval environment variable.")
    st.stop()

client = Together(api_key=TOGETHER_API_KEY)
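
# BGE embeddings are normalized so similarity search behaves like cosine
# similarity; this should match how the persisted collections were indexed.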
EMBED_MODEL_NAME = "BAAI/bge-base-en"
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL_NAME,
    encode_kwargs={"normalize_embeddings": True},
)
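
# --- Collection selection ------------------------------------------------------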
st.sidebar.title("DocChatter RAG")
collection = st.sidebar.selectbox(
    "Choose a document collection:",
    ["General Medicine", "RespiratoryFishman", "RespiratoryMurray", "MedMRCP2", "OldMedicine"],
)
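
# Map each display name to its on-disk Chroma persist directory and to the
# collection name stored inside that directory.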
dirs = {
    "General Medicine": "./oxfordmedbookdir/",
    "RespiratoryFishman": "./respfishmandbcud/",
    "RespiratoryMurray": "./respmurray/",
    "MedMRCP2": "./medmrcp2store/",
    "OldMedicine": "./mrcpchromadb/",
}
cols = {
    "General Medicine": "oxfordmed",
    "RespiratoryFishman": "fishmannotescud",
    "RespiratoryMurray": "respmurraynotes",
    "MedMRCP2": "medmrcp2notes",
    "OldMedicine": "mrcppassmednotes",
}

persist_directory = dirs[collection]
collection_name = cols[collection]
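
# Open the persisted Chroma collection and expose it as a retriever returning
# the top 20 chunks per query.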
vectorstore = Chroma(
    collection_name=collection_name,
    persist_directory=persist_directory,
    embedding_function=embeddings,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 20})
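
# --- Prompt construction -------------------------------------------------------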
def build_system(context: str) -> dict:
    """
    Build a comprehensive system prompt:
    - Act as an expert medical assistant and attentive listener.
    - Leverage retrieved context to craft detailed, accurate, and empathetic responses.
    - Ask clarifying follow-up questions if the user's query is ambiguous.
    - Structure answers clearly with headings, bullet points, and step-by-step explanations.
    - Cite relevant context sections when appropriate.
    - Maintain conversational memory for follow-up continuity.
    """
    prompt = f"""
You are a world-class medical assistant and conversational partner.

Listen carefully to the user’s questions, reference the context below, and provide a thorough, evidence-based response.
If any part of the question is unclear, ask a clarifying question before proceeding.
Organize your answer with clear headings or bullet points, and refer back to specific context snippets as needed.
Always be empathetic, concise, and precise in your medical explanations.
Retain memory of previous user messages to support follow-up interactions.

=== Retrieved Context Start ===
{context}
=== Retrieved Context End ===
"""
    return {"role": "system", "content": prompt}
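
# --- Chat UI -------------------------------------------------------------------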
st.title("🩺 DocChatter RAG (Streaming & Memory)")
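
# Conversation memory lives in Streamlit's session state so it survives reruns.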
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_prompt = st.chat_input("Ask anything about your docs…")

chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])

with chat_tab:
    # Replay the conversation so far.
    for msg in st.session_state.chat_history:
        st.chat_message(msg["role"]).write(msg["content"])

    if user_prompt:
        # Show and remember the user's message.
        st.chat_message("user").write(user_prompt)
        st.session_state.chat_history.append({"role": "user", "content": user_prompt})

        # Retrieve supporting chunks. Retriever.invoke() takes the query string
        # directly; fall back to the legacy API on older LangChain versions.
        try:
            docs = retriever.invoke(user_prompt)
        except Exception:
            docs = retriever.get_relevant_documents(user_prompt)
        context = "\n---\n".join(d.page_content for d in docs)

        # Assemble the model input: the system prompt with fresh context, then
        # the full conversation history (which already ends with the new question).
        messages = [build_system(context)] + st.session_state.chat_history

        response_container = st.chat_message("assistant")
        stream_placeholder = response_container.empty()
        answer = ""

        # Stream the completion into the placeholder as chunks arrive.
        for chunk in client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=messages,
            max_tokens=22048,
            temperature=0.1,
            stream=True,
        ):
            try:
                delta = getattr(chunk.choices[0].delta, "content", "")
                if delta:
                    answer += delta
                    stream_placeholder.write(answer)
            except (IndexError, AttributeError):
                continue

        st.session_state.chat_history.append({"role": "assistant", "content": answer})

with clear_tab:
    if st.button("🗑️ Clear chat history"):
        st.session_state.chat_history = []
        st.rerun()