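"""DocChatter RAG: a Streamlit chat app that answers questions over persisted
Chroma document collections, streaming TogetherAI completions and keeping
per-session chat memory."""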
import os
import streamlit as st
from together import Together
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# --- Configuration ---
# TogetherAI API key (env var name pilotikval)
TOGETHER_API_KEY = os.environ.get("pilotikval")
if not TOGETHER_API_KEY:
    st.error("Missing pilotikval environment variable.")
    st.stop()
# Initialize TogetherAI client
client = Together(api_key=TOGETHER_API_KEY)
# Embeddings setup
EMBED_MODEL_NAME = "BAAI/bge-base-en"
embeddings = HuggingFaceBgeEmbeddings(
    model_name=EMBED_MODEL_NAME,
    encode_kwargs={"normalize_embeddings": True},
)
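# NOTE: queries are embedded with this model at search time, so it must match
# the embedding model that was used when the Chroma collections were built.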
# Sidebar: select collection
st.sidebar.title("DocChatter RAG")
collection = st.sidebar.selectbox(
"Choose a document collection:",
['General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine']
)
dirs = {
    'General Medicine': './oxfordmedbookdir/',
    'RespiratoryFishman': './respfishmandbcud/',
    'RespiratoryMurray': './respmurray/',
    'MedMRCP2': './medmrcp2store/',
    'OldMedicine': './mrcpchromadb/'
}
cols = {
    'General Medicine': 'oxfordmed',
    'RespiratoryFishman': 'fishmannotescud',
    'RespiratoryMurray': 'respmurraynotes',
    'MedMRCP2': 'medmrcp2notes',
    'OldMedicine': 'mrcppassmednotes'
}
persist_directory = dirs[collection]
collection_name = cols[collection]
# Load Chroma vector store
vectorstore = Chroma(
    collection_name=collection_name,
    persist_directory=persist_directory,
    embedding_function=embeddings
)
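# NOTE: this assumes the persisted collection already exists on disk; if it
# doesn't, Chroma silently starts an empty one and retrieval returns nothing.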
retriever = vectorstore.as_retriever(search_kwargs={"k": 20})  # retrieve the top 20 chunks per query
# System prompt builder: injects the retrieved context and asks for detailed, complete answers
def build_system(context: str) -> dict:
    """Build the system message with the retrieved document context inlined."""
    return {
        "role": "system",
        "content": (
            "You are an expert medical assistant. Provide a thorough, detailed, and complete answer. "
            "If you don't know, say you don't know.\n"
            "Use the following context from medical docs to answer.\n\n"
            "Context:\n" + context
        )
    }
st.title("🩺 DocChatter RAG (Streaming & Memory)")
# Initialize chat history
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []  # list of dicts {role, content}
# Get user input at top level
user_prompt = st.chat_input("Ask anything about your docs…")
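# (st.chat_input renders pinned to the bottom of the app, so calling it once at
# the top level keeps it usable from either tab.)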
# Tabs for UI
chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])
with chat_tab:
    # Display existing chat
    for msg in st.session_state.chat_history:
        st.chat_message(msg['role']).write(msg['content'])
    # On new input
    if user_prompt:
        # Echo the user message and store it in memory
        st.chat_message("user").write(user_prompt)
        st.session_state.chat_history.append({"role": "user", "content": user_prompt})
        # Retrieve the top-k docs for this query (invoke() supersedes the
        # deprecated get_relevant_documents())
        docs = retriever.invoke(user_prompt)
        context = "\n---\n".join(d.page_content for d in docs)
        # Build message sequence: system prompt (with context) + full history
        messages = [build_system(context)]
        messages.extend(st.session_state.chat_history)
        # Stream the assistant response token by token
        response_container = st.chat_message("assistant")
        stream_placeholder = response_container.empty()
        answer = ""
        for chunk in client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=messages,
            max_tokens=22048,
            temperature=0.1,
            stream=True
        ):
            # Some stream chunks carry no choices/content (e.g. keep-alives);
            # checking choices directly avoids an IndexError on empty lists
            if chunk.choices and chunk.choices[0].delta.content:
                answer += chunk.choices[0].delta.content
                stream_placeholder.write(answer)
        # Save assistant response to memory
        st.session_state.chat_history.append({"role": "assistant", "content": answer})
with clear_tab:
    if st.button("🗑️ Clear chat history"):
        st.session_state.chat_history = []
        st.rerun()  # st.rerun() replaces the deprecated st.experimental_rerun()
# (Optional) persist new docs
# vectorstore.persist()
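# Usage (assuming this file is saved as app.py):
#   export pilotikval=<your TogetherAI API key>
#   streamlit run app.py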