Sbnos committed on
Commit
ddbcdf4
·
verified ·
1 Parent(s): 5d8b8f9

changing things for better retrieval

Browse files
Files changed (1) hide show
  1. app.py +19 -39
app.py CHANGED
@@ -5,23 +5,16 @@ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
5
  from langchain_community.llms import Together
6
  from langchain import hub
7
  from operator import itemgetter
8
- from langchain.schema.runnable import RunnableParallel
9
- from langchain.schema import format_document
10
- from typing import List, Tuple
11
  from langchain.chains import LLMChain
12
- from langchain.chains import RetrievalQA
13
- from langchain.schema.output_parser import StrOutputParser
14
- from langchain_community.chat_message_histories import StreamlitChatMessageHistory
15
- from langchain.memory import ConversationBufferMemory
16
  from langchain.chains import ConversationalRetrievalChain
17
- from langchain.memory import ConversationSummaryMemory
18
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
19
- from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
 
20
  import time
21
 
22
  # Load the embedding function
23
  model_name = "BAAI/bge-base-en"
24
- encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
25
 
26
  embedding_function = HuggingFaceBgeEmbeddings(
27
  model_name=model_name,
@@ -47,7 +40,7 @@ llmc = Together(
47
  )
48
 
49
  msgs = StreamlitChatMessageHistory(key="langchain_messages")
50
- memory = ConversationBufferMemory(chat_memory=msgs)
51
 
52
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
53
 
@@ -60,19 +53,15 @@ def _combine_documents(
60
chistory = []


def store_chat_history(role: str, content: str):
    """Record a single chat message in the module-level history list."""
    entry = dict(role=role, content=content)
    chistory.append(entry)
65
 
66
- # Define the Streamlit app
67
  def app():
68
  with st.sidebar:
69
  st.title("dochatter")
70
- # Create a dropdown selection box
71
  option = st.selectbox(
72
  'Which retriever would you like to use?',
73
  ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
74
  )
75
- # Depending on the selected option, choose the appropriate retriever
76
  if option == 'RespiratoryFishman':
77
  persist_directory = "./respfishmandbcud/"
78
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
@@ -94,37 +83,28 @@ def app():
94
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
95
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
96
 
97
- # Session State
98
  if "messages" not in st.session_state.keys():
99
  st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
100
 
101
- _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question which contains the themes of the conversation. Do not write the question. Do not write the answer.
102
  Chat History:
103
  {chat_history}
104
- Follow Up Input: {question}
105
  Standalone question:"""
106
- CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
107
 
108
- template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough. Answer the question based on the following context:
109
  {context}
110
- Question: {question}
111
- """
112
- ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
113
-
114
- _inputs = RunnableParallel(
115
- standalone_question=RunnablePassthrough.assign(
116
- chat_history=lambda x: chistory
117
- )
118
- | CONDENSE_QUESTION_PROMPT
119
- | llmc
120
- | StrOutputParser(),
121
  )
122
- _context = {
123
- "context": itemgetter("standalone_question") | retriever | _combine_documents,
124
- "question": lambda x: x["standalone_question"],
125
- }
126
-
127
- conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | llm
128
 
129
  st.header("Ask Away!")
130
  for message in st.session_state.messages:
@@ -142,7 +122,7 @@ def app():
142
  if st.session_state.messages[-1]["role"] != "assistant":
143
  with st.chat_message("assistant"):
144
  with st.spinner("Thinking..."):
145
- for _ in range(3): # Retry up to 3 times
146
  try:
147
  response = conversational_qa_chain.invoke(
148
  {
@@ -156,7 +136,7 @@ def app():
156
  break
157
  except Exception as e:
158
  st.error(f"An error occurred: {e}")
159
- time.sleep(2) # Wait 2 seconds before retrying
160
 
161
  if __name__ == '__main__':
162
  app()
 
5
  from langchain_community.llms import Together
6
  from langchain import hub
7
  from operator import itemgetter
 
 
 
8
  from langchain.chains import LLMChain
 
 
 
 
9
  from langchain.chains import ConversationalRetrievalChain
 
10
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
11
+ from langchain.memory import ConversationBufferMemory
12
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
13
  import time
14
 
15
  # Load the embedding function
16
  model_name = "BAAI/bge-base-en"
17
+ encode_kwargs = {'normalize_embeddings': True}
18
 
19
  embedding_function = HuggingFaceBgeEmbeddings(
20
  model_name=model_name,
 
40
  )
41
 
42
  msgs = StreamlitChatMessageHistory(key="langchain_messages")
43
+ memory = ConversationBufferMemory(chat_memory=msgs, memory_key="chat_history", return_messages=True)
44
 
45
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
46
 
 
53
chistory = []


def store_chat_history(role: str, content: str):
    """Record a single chat message in the module-level history list."""
    entry = dict(role=role, content=content)
    chistory.append(entry)
57
 
 
58
  def app():
59
  with st.sidebar:
60
  st.title("dochatter")
 
61
  option = st.selectbox(
62
  'Which retriever would you like to use?',
63
  ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
64
  )
 
65
  if option == 'RespiratoryFishman':
66
  persist_directory = "./respfishmandbcud/"
67
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
 
83
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
84
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
85
 
 
86
  if "messages" not in st.session_state.keys():
87
  st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
88
 
89
+ condense_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question which contains the themes of the conversation.
90
  Chat History:
91
  {chat_history}
92
+ Follow-Up Input: {question}
93
  Standalone question:"""
94
+ CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_template)
95
 
96
+ answer_template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough. Answer the question based on the following context:
97
  {context}
98
+ Question: {question}"""
99
+ ANSWER_PROMPT = ChatPromptTemplate.from_template(answer_template)
100
+
101
+ conversational_qa_chain = ConversationalRetrievalChain(
102
+ retriever=retriever,
103
+ memory=memory,
104
+ combine_docs_chain=_combine_documents,
105
+ condense_question_chain=LLMChain(llm=llmc, prompt=CONDENSE_QUESTION_PROMPT),
106
+ qa_chain=LLMChain(llm=llm, prompt=ANSWER_PROMPT)
 
 
107
  )
 
 
 
 
 
 
108
 
109
  st.header("Ask Away!")
110
  for message in st.session_state.messages:
 
122
  if st.session_state.messages[-1]["role"] != "assistant":
123
  with st.chat_message("assistant"):
124
  with st.spinner("Thinking..."):
125
+ for _ in range(3):
126
  try:
127
  response = conversational_qa_chain.invoke(
128
  {
 
136
  break
137
  except Exception as e:
138
  st.error(f"An error occurred: {e}")
139
+ time.sleep(2)
140
 
141
  if __name__ == '__main__':
142
  app()