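# dochatter: Streamlit chat app for doctors.
# Retrieval-augmented QA over medical textbook collections stored in Chroma,
# answered by Together-hosted Mixtral models wired together with LangChain (LCEL).
# Launch with: streamlit run <this_script>.py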
import streamlit as st
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_together import Together
from langchain import hub
from operator import itemgetter
from langchain.schema import format_document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.memory import StreamlitChatMessageHistory, ConversationBufferMemory
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
# Load the embedding function
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)
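# Note: queries must be embedded with the same model that was used to build the
# Chroma collections loaded in app(); bge-base-en with normalized embeddings is assumed here.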
# Initialize the LLMs
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_new_tokens=22000,
    top_k=12,
    together_api_key=os.environ['pilotikval']
)
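# A second model instance with a small output budget, used only to condense
# follow-up questions into standalone questions.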
llmc = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_new_tokens=1000,
    top_k=3,
    together_api_key=os.environ['pilotikval']
)
# Memory setup
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)
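# msgs/memory persist the conversation in st.session_state; the chain below
# receives the chat history explicitly as plain text built in app().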
# Define the prompt templates
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
"""Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
)
ANSWER_PROMPT = ChatPromptTemplate.from_template(
"""You are helping a doctor. Answer based on the provided context:
{context}
Question: {question}"""
)
# Function to combine documents
def _combine_documents(docs, document_prompt=PromptTemplate.from_template("{page_content}"), document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
# Define the chain using LCEL
def build_conversational_qa_chain(retriever):
    # 1. Rephrase the follow-up into a standalone question using the chat history.
    standalone_question = (
        {"chat_history": itemgetter("chat_history"), "question": itemgetter("question")}
        | CONDENSE_QUESTION_PROMPT
        | llmc
    )
    # 2. Retrieve documents for the standalone question and merge them into one context string.
    context = standalone_question | retriever | _combine_documents
    # 3. Answer the original question from the retrieved context.
    return (
        {"context": context, "question": itemgetter("question")}
        | ANSWER_PROMPT
        | llm
    )
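# The chain is built inside app() because the retriever depends on the Chroma
# collection the user picks in the sidebar.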
# Define the Streamlit app
def app():
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )

    # Define the retriever based on the selected option
    persist_directory = {
        'General Medicine': "./oxfordmedbookdir/",
        'RespiratoryFishman': "./respfishmandbcud/",
        'RespiratoryMurray': "./respmurray/",
        'MedMRCP2': "./medmrcp2store/",
        'OldMedicine': "./mrcpchromadb/"
    }.get(option, "./mrcpchromadb/")
    collection_name = {
        'General Medicine': "oxfordmed",
        'RespiratoryFishman': "fishmannotescud",
        'RespiratoryMurray': "respmurraynotes",
        'MedMRCP2': "medmrcp2notes",
        'OldMedicine': "mrcppassmednotes"
    }.get(option, "mrcppassmednotes")

    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_function,
        collection_name=collection_name
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    conversational_qa_chain = build_conversational_qa_chain(retriever)

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    st.header("Ask Away!")

    # Replay the stored conversation and rebuild the plain-text chat history for the chain.
    chistory = []
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
        chistory.append(f"{message['role']}: {message['content']}")

    prompts2 = st.chat_input("Say something")
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = conversational_qa_chain.invoke(
                    {
                        "question": prompts2,
                        "chat_history": "\n".join(chistory),
                    }
                )
                st.write(response)
                message = {"role": "assistant", "content": response}
                st.session_state.messages.append(message)


if __name__ == '__main__':
    app()