import os
import streamlit as st
from together import Together
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

# --- Configuration ---
# Together AI API key (read from the env var named "pilotikval")
TOGETHER_API_KEY = os.environ.get("pilotikval")
if not TOGETHER_API_KEY:
    st.error("Missing pilotikval environment variable.")
    st.stop()

# Initialize TogetherAI client
client = Together(api_key=TOGETHER_API_KEY)

# Embeddings setup
EMBED_MODEL_NAME = "BAAI/bge-base-en"
embeddings = HuggingFaceBgeEmbeddings(
    model_name=EMBED_MODEL_NAME,
    encode_kwargs={"normalize_embeddings": True},
)
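
# NOTE: every persisted collection queried below must have been built with this
# same embedding model; mixing embedding models between indexing and querying
# silently degrades similarity search.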

# Sidebar: select collection
st.sidebar.title("DocChatter RAG")
collection = st.sidebar.selectbox(
    "Choose a document collection:",
    ['General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine']
)

dirs = {
    'General Medicine': './oxfordmedbookdir/',
    'RespiratoryFishman': './respfishmandbcud/',
    'RespiratoryMurray': './respmurray/',
    'MedMRCP2': './medmrcp2store/',
    'OldMedicine': './mrcpchromadb/'
}
cols = {
    'General Medicine': 'oxfordmed',
    'RespiratoryFishman': 'fishmannotescud',
    'RespiratoryMurray': 'respmurraynotes',
    'MedMRCP2': 'medmrcp2notes',
    'OldMedicine': 'mrcppassmednotes'
}

persist_directory = dirs[collection]
collection_name = cols[collection]

# Load Chroma vector store
vectorstore = Chroma(
    collection_name=collection_name,
    persist_directory=persist_directory,
    embedding_function=embeddings
)
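# Each persist_directory is assumed to hold a prebuilt index; this app only
# queries it and adds no documents at runtime.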
retriever = vectorstore.as_retriever(search_kwargs={"k": 20})  # retrieve the 20 most similar chunks per query

# System prompt template with instruction for detailed long answers
def build_system(context: str) -> dict:
    return {
        "role": "system",
        "content": (
            "You are an expert medical assistant. Provide a thorough, detailed, and complete answer. "
            "If you don't know, say you don't know.\n"
            "Use the following context from medical docs to answer.\n\n"
            "Context:\n" + context
        )
    }

st.title("🩺 DocChatter RAG (Streaming & Memory)")

# Initialize chat history
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []  # list of dicts {role, content}

# Get user input at top level
user_prompt = st.chat_input("Ask anything about your docs…")

# Tabs for UI
chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])

with chat_tab:
    # Display existing chat
    for msg in st.session_state.chat_history:
        st.chat_message(msg['role']).write(msg['content'])

    # On new input
    if user_prompt:
        # Echo user
        st.chat_message("user").write(user_prompt)
        st.session_state.chat_history.append({"role": "user", "content": user_prompt})

        # Retrieve top-k docs
        docs = retriever.get_relevant_documents(user_prompt)
        context = "\n---\n".join([d.page_content for d in docs])

        # Build message sequence: system + full history
        messages = [build_system(context)]
        for m in st.session_state.chat_history:
            messages.append(m)
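
        # Note: the retrieved context plus the entire chat history is resent on
        # every turn, so long conversations can overflow the model's context
        # window. A minimal sketch of one mitigation (the window of 20 turns is
        # an arbitrary assumption):
        # messages = [build_system(context)] + st.session_state.chat_history[-20:]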

        # Prepare streaming response
        response_container = st.chat_message("assistant")
        stream_placeholder = response_container.empty()
        answer = ""

        # Stream tokens as they arrive
        for chunk in client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=messages,
            max_tokens=22048,
            temperature=0.1,
            stream=True
        ):
            # Guard against chunks with no choices or an empty delta
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                delta = chunk.choices[0].delta.content
                answer += delta
                stream_placeholder.write(answer)

        # Save assistant response
        st.session_state.chat_history.append({"role": "assistant", "content": answer})

with clear_tab:
    if st.button("🗑️ Clear chat history"):
        st.session_state.chat_history = []
        st.rerun()  # st.experimental_rerun() was removed in newer Streamlit releases

# (Optional) persist new docs
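# A minimal sketch, assuming `new_texts` is a hypothetical list of strings to index:
# vectorstore.add_texts(new_texts)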
# vectorstore.persist()