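# dochatter: a Streamlit chat app that answers medical questions with
# retrieval-augmented generation over per-textbook Chroma collections,
# using BGE embeddings and Together-hosted Mixtral models.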
import streamlit as st
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_together import Together
from operator import itemgetter
from langchain.schema import format_document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.memory import StreamlitChatMessageHistory, ConversationBufferMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
# Load the embedding function
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)
# Initialize the LLMs: llm answers questions, llmc condenses follow-ups into standalone questions
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_new_tokens=22000,
    top_k=12,
    together_api_key=os.environ['pilotikval']
)
llmc = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_new_tokens=1000,
    top_k=3,
    together_api_key=os.environ['pilotikval']
)
# Memory setup
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)
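# The code below references store_chat_history() and a `chistory` list that are not
# defined in this file. The helper here is a minimal sketch of what they likely
# looked like: a plain list of {role, content} dicts, rebuilt on every rerun and
# passed to the condense-question prompt as chat history.
chistory = []

def store_chat_history(role, content):
    # Append one message to the history used by the condense-question step
    chistory.append({"role": role, "content": content})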
# Define the prompt templates
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
    """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
)
ANSWER_PROMPT = ChatPromptTemplate.from_template(
    """You are helping a doctor. Answer based on the provided context:
{context}
Question: {question}"""
)
# Function to combine retrieved documents into a single context string
def _combine_documents(docs, document_prompt=PromptTemplate.from_template("{page_content}"), document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
# Define the conversational retrieval chain using LCEL.
# The retriever depends on the collection picked in the sidebar, so the chain is
# assembled in a helper and built once the retriever exists.
def build_conversational_qa_chain(retriever):
    # 1. Rewrite the follow-up question into a standalone question using the chat history.
    standalone_question = RunnableParallel(
        standalone_question=CONDENSE_QUESTION_PROMPT | llmc | StrOutputParser()
    )
    # 2. Retrieve context for the standalone question and keep the question alongside it.
    retrieval = {
        "context": itemgetter("standalone_question") | retriever | _combine_documents,
        "question": itemgetter("standalone_question"),
    }
    # 3. Answer from the retrieved context.
    return standalone_question | retrieval | ANSWER_PROMPT | llm
# Define the Streamlit app
def app():
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )
    # Define the retriever based on the selected option
    persist_directory = {
        'General Medicine': "./oxfordmedbookdir/",
        'RespiratoryFishman': "./respfishmandbcud/",
        'RespiratoryMurray': "./respmurray/",
        'MedMRCP2': "./medmrcp2store/",
        'OldMedicine': "./mrcpchromadb/"
    }.get(option, "./mrcpchromadb/")
    collection_name = {
        'General Medicine': "oxfordmed",
        'RespiratoryFishman': "fishmannotescud",
        'RespiratoryMurray': "respmurraynotes",
        'MedMRCP2': "medmrcp2notes",
        'OldMedicine': "mrcppassmednotes"
    }.get(option, "mrcppassmednotes")
    vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name=collection_name)
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})
    # Build the chain here so it uses the retriever for the selected collection
    conversational_qa_chain = build_conversational_qa_chain(retriever)
if "messages" not in st.session_state: | |
st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}] | |
st.header("Ask Away!") | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.write(message["content"]) | |
store_chat_history(message["role"], message["content"]) | |
    prompts2 = st.chat_input("Say something")
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = conversational_qa_chain.invoke(
                    {
                        "question": prompts2,
                        "chat_history": chistory,
                    }
                )
            st.write(response)
        message = {"role": "assistant", "content": response}
        st.session_state.messages.append(message)
if __name__ == '__main__':
    app()
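# To run locally (assuming this file is saved as app.py and the Together API key is
# exported under the env var name used above):
#   export pilotikval=<your Together API key>
#   streamlit run app.py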