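"""DocHatter: a Streamlit RAG chatbot over persisted Chroma collections.

A follow-up question is first condensed into a standalone query (Mixtral-8x7B),
the selected Chroma collection is searched using BGE embeddings, and the
retrieved context is answered over by Mixtral-8x22B via the Together API.
"""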
import os
import time

import streamlit as st
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import Together
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema import format_document
# Load the embedding function
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)
# Load the answering LLM (the Together API key is read from the 'pilotikval' env var)
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_tokens=19096,
    top_k=10,
    together_api_key=os.environ['pilotikval']
)

# Load the smaller LLM used to condense follow-up questions
llmc = Together(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.2,
    max_tokens=1024,
    top_k=1,
    together_api_key=os.environ['pilotikval']
)
# Session-backed message history (set up here but not wired into the chains below)
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs, memory_key="chat_history", return_messages=True)
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
    # Render each retrieved Document through the document prompt and join them
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
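# Illustrative only: with the default "{page_content}" prompt, two Documents
# whose page_content values are "a" and "b" combine to "a\n\nb". The helper is
# kept for reference; create_stuff_documents_chain below does the same stuffing.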
chistory = []

def store_chat_history(role: str, content: str):
    chistory.append({"role": role, "content": content})
def app():
    # Sidebar: pick which persisted Chroma collection to chat with
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )

    # Open the matching vector store and build a retriever over it
    if option == 'RespiratoryFishman':
        persist_directory = "./respfishmandbcud/"
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    elif option == 'RespiratoryMurray':
        persist_directory = "./respmurray/"
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="respmurraynotes")
        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    elif option == 'MedMRCP2':
        persist_directory = "./medmrcp2store/"
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="medmrcp2notes")
        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    elif option == 'General Medicine':
        persist_directory = "./oxfordmedbookdir/"
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="oxfordmed")
        retriever = vectordb.as_retriever(search_kwargs={"k": 7})
    else:
        # 'OldMedicine' falls through to the MRCP store
        persist_directory = "./mrcpchromadb/"
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
if "messages" not in st.session_state.keys():
st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
    # Both chains below pass the user's question under the "input" key, so the
    # prompts must use {input} rather than {question}
    condense_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question which contains the themes of the conversation.
Chat History:
{chat_history}
Follow-Up Input: {input}
Standalone question:"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_template)

    answer_template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough as possible. Answer the question based on the following context:
{context}
Question: {input}"""
    ANSWER_PROMPT = ChatPromptTemplate.from_template(answer_template)
    # Condense the follow-up question with llmc, then fetch matching chunks
    history_aware_retriever = create_history_aware_retriever(
        llm=llmc,
        retriever=retriever,
        prompt=CONDENSE_QUESTION_PROMPT
    )
    # Stuff the retrieved documents into ANSWER_PROMPT and answer with llm
    combine_docs_chain = create_stuff_documents_chain(llm, ANSWER_PROMPT)
    conversational_qa_chain = create_retrieval_chain(
        history_aware_retriever,
        combine_docs_chain
    )
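    # Per the create_retrieval_chain contract, invoke() returns a dict: the
    # retrieved Documents land under "context" and the model's reply under
    # "answer", which is what the render step below reads.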
    st.header("Ask Away!")

    # Render the transcript and mirror it into chistory for the condense step
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
        store_chat_history(message["role"], message["content"])

    prompts2 = st.chat_input("Say something")
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)
    # Generate a reply, retrying up to three times on transient API errors
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                for _ in range(3):
                    try:
                        response = conversational_qa_chain.invoke(
                            {
                                "input": prompts2,
                                "chat_history": chistory,
                            }
                        )
                        st.write(response["answer"])
                        message = {"role": "assistant", "content": response["answer"]}
                        st.session_state.messages.append(message)
                        break
                    except Exception as e:
                        st.error(f"An error occurred: {e}")
                        time.sleep(2)
if __name__ == '__main__':
    app()
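# To run locally (assuming this file is saved as app.py and the persisted
# Chroma directories above exist):
#   export pilotikval=<your Together API key>
#   streamlit run app.py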