import streamlit as st
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_together import Together
from operator import itemgetter
from langchain.schema import format_document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.memory import StreamlitChatMessageHistory, ConversationBufferMemory
from langchain_core.runnables import RunnablePassthrough

# Load the embedding function
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
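# Normalized embeddings are recommended for BGE models so similarity search behaves like cosine similarity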
embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)

# Initialize the LLMs
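# 'llm' generates the final answers; 'llmc' (below) only condenses follow-up questions
# into standalone queries, so it gets a much smaller output budget.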
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_new_tokens=22000,
    top_k=12,
    together_api_key=os.environ['pilotikval']
)

llmc = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_new_tokens=1000,
    top_k=3,
    together_api_key=os.environ['pilotikval']
)

# Memory setup
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)
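# StreamlitChatMessageHistory keeps the messages in st.session_state under the given key, so they persist across reruns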

# Define the prompt templates
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
    """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
    Chat History:
    {chat_history}
    Follow Up Input: {question}
    Standalone question:"""
)

ANSWER_PROMPT = ChatPromptTemplate.from_template(
    """You are helping a doctor. Answer based on the provided context:
    {context}
    Question: {question}"""
)

# Function to combine documents
def _combine_documents(docs, document_prompt=PromptTemplate.from_template("{page_content}"), document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)

# Define the chain using LCEL: condense the follow-up question using the chat history,
# retrieve context for the standalone question, then generate the answer.
def build_conversational_qa_chain(retriever):
    condense_question = (
        {"chat_history": itemgetter("chat_history"), "question": itemgetter("question")}
        | CONDENSE_QUESTION_PROMPT
        | llmc
    )
    retrieve_and_answer = (
        {"context": retriever | _combine_documents, "question": RunnablePassthrough()}
        | ANSWER_PROMPT
        | llm
    )
    return condense_question | retrieve_and_answer

# Define the Streamlit app
def app():
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )
        
        # Define retrievers based on option
        persist_directory = {
            'General Medicine': "./oxfordmedbookdir/",
            'RespiratoryFishman': "./respfishmandbcud/",
            'RespiratoryMurray': "./respmurray/",
            'MedMRCP2': "./medmrcp2store/",
            'OldMedicine': "./mrcpchromadb/"
        }.get(option, "./mrcpchromadb/")
        
        collection_name = {
            'General Medicine': "oxfordmed",
            'RespiratoryFishman': "fishmannotescud",
            'RespiratoryMurray': "respmurraynotes",
            'MedMRCP2': "medmrcp2notes",
            'OldMedicine': "mrcppassmednotes"
        }.get(option, "mrcppassmednotes")
        
        # Open the pre-built Chroma collection for the selected option and expose it as a retriever
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name=collection_name)
        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
        conversational_qa_chain = build_conversational_qa_chain(retriever)
    
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
    
    st.header("Ask Away!")
    # Replay the conversation so far and rebuild the chat history string passed to the chain
    chat_history = ""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
        chat_history += f"{message['role']}: {message['content']}\n"
    
    prompts2 = st.chat_input("Say something")

    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

    # Generate a reply only when the latest message came from the user
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = conversational_qa_chain.invoke(
                    {
                        "question": prompts2,
                        "chat_history": chat_history,
                    }
                )
                st.write(response)
        message = {"role": "assistant", "content": response}
        st.session_state.messages.append(message)

if __name__ == '__main__':
    app()