# medchat2/app.py — Streamlit medical-document chat app.
# (Replaced HuggingFace file-viewer chrome that had been pasted above the code.)
import html
import os
import time
from typing import List

import streamlit as st
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import Together
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema import format_document
from langchain.vectorstores import Chroma
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
# Load the embedding function
# BGE embeddings shared by every Chroma collection opened in app().
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
embedding_function = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
# Load the LLM
# NOTE(review): max_tokens=19096 looks unusually large — confirm against Together's
# limits for this model. Reading os.environ['pilotikval'] raises KeyError at import
# time if the API key env var is unset.
llm = Together(model="mistralai/Mixtral-8x22B-Instruct-v0.1", temperature=0.2, max_tokens=19096, top_k=10, together_api_key=os.environ['pilotikval'], streaming=True)
# Streamlit-backed message store + conversation memory. NOTE(review): `memory` is
# not referenced anywhere else in this file — presumably kept for future chain
# wiring; verify before removing.
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)
# Default rendering for a retrieved document: its raw page content only.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    """Render each document through *document_prompt* and join the results.

    Returns a single string with the formatted documents separated by
    *document_separator* (two newlines by default).
    """
    return document_separator.join(
        format_document(single_doc, document_prompt) for single_doc in docs
    )
# Module-level transcript mirrored from st.session_state; used to build the
# chat-history string passed to the LLM prompt.
chistory = []


def store_chat_history(role: str, content: str):
    """Append one {"role": ..., "content": ...} record to the transcript."""
    entry = {"role": role, "content": content}
    chistory.append(entry)
def render_message_with_copy_button(role: str, content: str, key: str):
    """Render one chat message as HTML with a copy-to-clipboard button.

    Args:
        role: Message author role (not rendered; kept for interface parity).
        content: Message text, shown in the bubble and copied on click.
        key: Unique DOM id for this message's hidden <textarea>.
    """
    # FIX: escape the message text before interpolating it into raw HTML that is
    # rendered with unsafe_allow_html=True. Without this, any '<' in user/LLM
    # output breaks the markup, and hostile content could inject script.
    safe_content = html.escape(content)
    html_code = f"""
    <div class="message" style="position: relative; padding-right: 40px;">
        <div class="message-content">{safe_content}</div>
        <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0; background-color: transparent; border: none; cursor: pointer;">
            <img src="https://img.icons8.com/material-outlined/24/grey/copy.png" alt="Copy">
        </button>
    </div>
    <textarea id="{key}" style="display:none;">{safe_content}</textarea>
    <script>
    function copyToClipboard(key) {{
        var copyText = document.getElementById(key);
        copyText.style.display = "block";
        copyText.select();
        document.execCommand("copy");
        copyText.style.display = "none";
        alert("Copied to clipboard");
    }}
    </script>
    """
    st.write(html_code, unsafe_allow_html=True)
def get_streaming_response(user_query, chat_history):
    """Build the answer prompt and stream the LLM's reply.

    Args:
        user_query: The user's latest question.
        chat_history: Pre-formatted "role: content" transcript string.

    Returns:
        A generator of response chunks from the module-level `llm`.
    """
    answer_prompt = ChatPromptTemplate.from_template(
        """
You are a knowledgeable assistant. Provide a detailed and thorough answer to the question based on the following context:
Chat history: {chat_history}
User question: {user_question}
"""
    )
    pipeline = answer_prompt | llm
    return pipeline.stream({"chat_history": chat_history, "user_question": user_query})
# Maps sidebar choice -> (Chroma persist directory, collection name, retriever k).
_RETRIEVER_CONFIGS = {
    'General Medicine': ("./oxfordmedbookdir/", "oxfordmed", 7),
    'RespiratoryFishman': ("./respfishmandbcud/", "fishmannotescud", 5),
    'RespiratoryMurray': ("./respmurray/", "respmurraynotes", 5),
    'MedMRCP2': ("./medmrcp2store/", "medmrcp2notes", 5),
    'OldMedicine': ("./mrcpchromadb/", "mrcppassmednotes", 5),
}


def app():
    """Streamlit entry point: corpus picker in the sidebar, chat UI in the body."""
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox('Which retriever would you like to use?', ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine'))

    # Table-driven store selection replaces the original if/elif chain; an unknown
    # option falls back to the OldMedicine store, matching the original else branch.
    persist_directory, collection_name, k = _RETRIEVER_CONFIGS.get(option, _RETRIEVER_CONFIGS['OldMedicine'])
    vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name=collection_name)
    # NOTE(review): this retriever is built but never queried below —
    # get_streaming_response() only receives the chat history, so retrieved
    # documents never reach the LLM. Wire `retriever` into the prompt to make
    # this an actual RAG pipeline.
    retriever = vectordb.as_retriever(search_kwargs={"k": k})

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    st.header("Ask Away!")
    # Re-render the transcript and mirror it into `chistory` for prompt building.
    for i, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            render_message_with_copy_button(message["role"], message["content"], key=f"message-{i}")
        store_chat_history(message["role"], message["content"])

    user_query = st.chat_input("Say something")
    if user_query:
        st.session_state.messages.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.write(user_query)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                chat_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in chistory)
                try:
                    response_text = ""
                    # Stream chunks into a single placeholder so the answer grows
                    # in place instead of stacking one element per chunk.
                    placeholder = st.empty()
                    for response_part in get_streaming_response(user_query, chat_history):
                        response_text += response_part
                        placeholder.write(response_text)
                    st.session_state.messages.append({"role": "assistant", "content": response_text})
                except Exception as e:
                    # Surface LLM/API failures in the UI rather than crashing the app.
                    st.error(f"An error occurred: {e}")


if __name__ == '__main__':
    app()