File size: 6,000 Bytes
4b7893b
 
9cae142
 
 
69cffa4
4b7893b
69cffa4
4b7893b
7db5382
398db91
4b7893b
 
 
 
69cffa4
4b7893b
 
69cffa4
35fbc0b
4b7893b
 
 
 
 
69cffa4
398db91
 
4b7893b
 
 
 
 
 
35fbc0b
 
 
 
69cffa4
 
 
35fbc0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69cffa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cae142
69cffa4
 
4b7893b
 
 
69cffa4
4b7893b
398db91
 
4b7893b
398db91
 
 
4b7893b
398db91
 
 
4b7893b
398db91
 
 
d98df44
4b7893b
398db91
 
4b7893b
 
 
 
 
d1601e3
35fbc0b
4b7893b
35fbc0b
4b7893b
 
69cffa4
 
 
4b7893b
69cffa4
4b7893b
 
 
69cffa4
9cae142
 
 
 
 
 
 
 
 
20a674a
4b7893b
398db91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import Together
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema import format_document
from typing import List
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
import time

# Load the embedding function
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
embedding_function = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)

# Load the LLM
llm = Together(model="mistralai/Mixtral-8x22B-Instruct-v0.1", temperature=0.2, max_tokens=19096, top_k=10, together_api_key=os.environ['pilotikval'], streaming=True)

msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)

chistory = []

def store_chat_history(role: str, content: str):
    chistory.append({"role": role, "content": content})

def render_message_with_copy_button(role: str, content: str, key: str):
    html_code = f"""
    <div class="message" style="position: relative; padding-right: 40px;">
        <div class="message-content">{content}</div>
        <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0; background-color: transparent; border: none; cursor: pointer;">
            <img src="https://img.icons8.com/material-outlined/24/grey/copy.png" alt="Copy">
        </button>
    </div>
    <textarea id="{key}" style="display:none;">{content}</textarea>
    <script>
        function copyToClipboard(key) {{
            var copyText = document.getElementById(key);
            copyText.style.display = "block";
            copyText.select();
            document.execCommand("copy");
            copyText.style.display = "none";
            alert("Copied to clipboard");
        }}
    </script>
    """
    st.write(html_code, unsafe_allow_html=True)

def get_streaming_response(user_query, chat_history):
    template = """
    You are a knowledgeable assistant. Provide a detailed and thorough answer to the question based on the following context:

    Chat history: {chat_history}

    User question: {user_question}
    """
    prompt = ChatPromptTemplate.from_template(template)
    
    inputs = {
        "chat_history": chat_history,
        "user_question": user_query
    }

    chain = prompt | llm
    return chain.stream(inputs)

def app():
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox('Which retriever would you like to use?', ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine'))
        if option == 'RespiratoryFishman':
            persist_directory = "./respfishmandbcud/"
            vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
            retriever = vectordb.as_retriever(search_kwargs={"k": 5})
        elif option == 'RespiratoryMurray':
            persist_directory = "./respmurray/"
            vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="respmurraynotes")
            retriever = vectordb.as_retriever(search_kwargs={"k": 5})
        elif option == 'MedMRCP2':
            persist_directory = "./medmrcp2store/"
            vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="medmrcp2notes")
            retriever = vectordb.as_retriever(search_kwargs={"k": 5})
        elif option == 'General Medicine':
            persist_directory = "./oxfordmedbookdir/"
            vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="oxfordmed")
            retriever = vectordb.as_retriever(search_kwargs={"k": 7})
        else:
            persist_directory = "./mrcpchromadb/"
            vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
            retriever = vectordb.as_retriever(search_kwargs={"k": 5})

    if "messages" not in st.session_state.keys():
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    st.header("Ask Away!")
    for i, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            render_message_with_copy_button(message["role"], message["content"], key=f"message-{i}")
            store_chat_history(message["role"], message["content"])

    user_query = st.chat_input("Say something")
    if user_query:
        st.session_state.messages.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.write(user_query)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chistory])
                try:
                    response_generator = get_streaming_response(user_query, chat_history)
                    response_text = ""
                    for response_part in response_generator:
                        response_text += response_part
                        st.write(response_text)
                    st.session_state.messages.append({"role": "assistant", "content": response_text})
                except Exception as e:
                    st.error(f"An error occurred: {e}")

if __name__ == '__main__':
    app()