mobinln committed
Commit
786f732
1 Parent(s): 0983911

feat: remove cache, add context expander

Files changed (4)
  1. .gitignore +2 -1
  2. app.py +13 -7
  3. llm.py +38 -13
  4. vector_store.py +6 -8
.gitignore CHANGED
@@ -1,3 +1,4 @@
 /__pycache__
 /temp
-/models
+/models
+/chroma
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import shutil
 import streamlit as st
 from llm import load_llm, response_generator
 from vector_store import load_vector_store, process_pdf
@@ -9,20 +10,26 @@ from uuid import uuid4
 repo_id = "Qwen/Qwen2.5-3B-Instruct-GGUF"
 filename = "qwen2.5-3b-instruct-q5_k_m.gguf"
 
+
 llm = load_llm(repo_id, filename)
+vector_store = load_vector_store()
+
 
 st.title("PDF QA")
 # Initialize chat history
 if "messages" not in st.session_state:
+    vector_store.reset_collection()
+    if os.path.exists("./temp"):
+        shutil.rmtree("./temp")
     st.session_state.messages = []
 
 # Display chat messages from history on app rerun
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         if message["role"] == "user":
-            st.markdown(message["content"])
+            st.write(message["content"])
         else:
-            st.code(message["content"])
+            st.write(message["content"])
 
 # Accept user input
 if prompt := st.chat_input("What is up?"):
@@ -34,13 +41,12 @@ if prompt := st.chat_input("What is up?"):
 
     # Display assistant response in chat message container
     with st.chat_message("assistant"):
-        vector_store = load_vector_store()
-        retriever = vector_store.as_retriever()
-        docs = retriever.get_relevant_documents(prompt)
-
+        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        response = response_generator(llm, st.session_state.messages, prompt, retriever)
 
        st.markdown(response["answer"])
+        with st.expander("See context"):
+            st.write(response["context"])
 
    # Add assistant response to chat history
    st.session_state.messages.append(
@@ -54,7 +60,7 @@ with st.sidebar:
        "Choose a PDF file", accept_multiple_files=True, type="pdf"
    )
    if uploaded_files is not None:
-        vector_store = load_vector_store()
+        st.session_state.uploaded_pdf = True
        for uploaded_file in uploaded_files:
            temp_dir = "./temp"
            if not os.path.exists(temp_dir):
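
Taken together, the app.py changes load the now-cached vector store once at module level, reset the collection and temp files when a new chat session starts, cap retrieval at the three most similar chunks, and surface those chunks in a "See context" expander under each answer. A condensed sketch of the resulting chat flow (sidebar upload handling omitted; the history-append lines sit outside this diff and are assumed here):

    import os
    import shutil

    import streamlit as st
    from llm import load_llm, response_generator
    from vector_store import load_vector_store

    # Cached with @st.cache_resource, so Streamlit reruns reuse the same objects.
    llm = load_llm("Qwen/Qwen2.5-3B-Instruct-GGUF", "qwen2.5-3b-instruct-q5_k_m.gguf")
    vector_store = load_vector_store()

    if "messages" not in st.session_state:
        # Fresh session: drop the persisted collection and any leftover temp PDFs.
        vector_store.reset_collection()
        if os.path.exists("./temp"):
            shutil.rmtree("./temp")
        st.session_state.messages = []

    if prompt := st.chat_input("What is up?"):
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):
            # Retrieve only the 3 most similar chunks for this question.
            retriever = vector_store.as_retriever(search_kwargs={"k": 3})
            response = response_generator(llm, st.session_state.messages, prompt, retriever)

            st.markdown(response["answer"])
            with st.expander("See context"):
                st.write(response["context"])  # the retrieved Documents

        st.session_state.messages.append({"role": "assistant", "content": response["answer"]})

Resetting inside the `if "messages" not in st.session_state` guard ties the Chroma collection to the lifetime of a chat session rather than to the cached resource itself.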
llm.py CHANGED
@@ -7,6 +7,10 @@ from langchain.chains import create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.prompts import ChatPromptTemplate
 
+from langchain_core.globals import set_debug
+
+set_debug(True)
+
 
 @st.cache_resource()
 def load_llm(repo_id, filename):
@@ -29,6 +33,8 @@ def load_llm(repo_id, filename):
         n_threads=4,
         n_threads_batch=4,
         n_ctx=8000,
+        max_tokens=128,
+        # stop=["."],
     )
     print(f"{repo_id} loaded successfully. ✅")
     return llm
@@ -36,26 +42,45 @@ def load_llm(repo_id, filename):
 
 # Streamed response emulator
 def response_generator(llm, messages, question, retriever):
+    # System prompt setting up context for the assistant
     system_prompt = (
-        "You are an assistant for question-answering tasks. "
-        "Use the following pieces of retrieved context to answer "
-        "the question. If you don't know the answer, say that you "
-        "don't know. Use three sentences maximum and keep the "
-        "answer concise."
+        "<|im_start|>system\n"
+        "You are an AI assistant specializing in question-answering tasks. "
+        "Utilize the provided context and past conversation to answer "
+        "the current question. If the answer is unknown, clearly state that you "
+        "don't know. Keep responses concise and direct."
         "\n\n"
-        "{context}"
+        "Context: {context}"
+        "\n<|im_end|>"
     )
 
-    prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", system_prompt),
-            ("user", "{input}"),
-        ]
-    )
+    # Prepare message history
+    message_history = [("system", system_prompt)]
+
+    # Append conversation history to messages
+    for message in messages:
+        if message["role"] == "user":
+            message_history.append(
+                ("user", "<|im_start|>user\n" + message["content"] + "\n<|im_end|>")
+            )
+        elif message["role"] == "assistant":
+            message_history.append(
+                (
+                    "assistant",
+                    "<|im_start|>assistant\n" + message["content"] + "\n<|im_end|>",
+                )
+            )
+
+    message_history.append(("assistant", "<|im_start|>assistant\n"))
+
+    # Create prompt template with full message history
+    prompt = ChatPromptTemplate.from_messages(message_history)
 
+    # Instantiate chains for document retrieval and question answering
     question_answer_chain = create_stuff_documents_chain(llm, prompt)
     rag_chain = create_retrieval_chain(retriever, question_answer_chain)
 
-    results = rag_chain.invoke({"input": question})
+    # Invoke RAG (retrieval-augmented generation) chain with current input
+    results = rag_chain.invoke({"input": question}, verbose=True)
 
     return results
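
The rewritten response_generator wires Qwen's ChatML markers (<|im_start|>, <|im_end|>) directly into the LangChain messages and replays the whole chat history on every turn. A small standalone illustration of what ChatPromptTemplate.from_messages produces from that structure; the sample strings are invented for the demo, and only {context} acts as a template variable:

    from langchain_core.prompts import ChatPromptTemplate

    # Mirrors the structure response_generator builds: a ChatML-tagged system
    # message with a {context} slot, one prior user turn, and an open assistant tag.
    message_history = [
        ("system", "<|im_start|>system\nAnswer from the context.\n\nContext: {context}\n<|im_end|>"),
        ("user", "<|im_start|>user\nWhat is this PDF about?\n<|im_end|>"),
        ("assistant", "<|im_start|>assistant\n"),
    ]
    prompt = ChatPromptTemplate.from_messages(message_history)

    # create_stuff_documents_chain is what fills {context} with the retrieved
    # chunks; a placeholder string is substituted here just to inspect the output.
    rendered = prompt.invoke({"context": "Example retrieved chunk."})
    for message in rendered.to_messages():
        print(message.type, repr(message.content))

The dict returned by rag_chain.invoke carries "input", "context" (the retrieved Documents), and "answer", which is why app.py can read response["answer"] and response["context"] above.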
vector_store.py CHANGED
@@ -1,7 +1,8 @@
 import streamlit as st
-import chromadb
+
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_chroma import Chroma
+from langchain_community.vectorstores import InMemoryVectorStore
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 
@@ -16,17 +17,14 @@ def load_embedding_model(model):
     return model
 
 
+@st.cache_resource()
 def load_vector_store():
-    """
-    Loads a simple vector store
-    I didn't use @st.cache because I want to
-    load vector store on every page load
-    """
-    model = load_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
-    chromadb.api.client.SharedSystemClient.clear_system_cache()
+    model = load_embedding_model("sentence-transformers/all-mpnet-base-v2")
+
     vector_store = Chroma(
         collection_name="main_store",
         embedding_function=model,
+        persist_directory="./chroma",
    )
    return vector_store
 
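
For reference, process_pdf is imported by app.py but its body is not part of this commit. Given the loaders already imported at the top of vector_store.py, such a helper typically looks roughly like the following; the signature, chunk sizes, and return value are assumptions, not the repository's actual code:

    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter


    def process_pdf(vector_store, file_path):
        """Hypothetical sketch: load one PDF, split it into overlapping chunks, index it."""
        documents = PyPDFLoader(file_path).load()

        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(documents)

        # The Chroma wrapper embeds the chunks with its embedding_function on insert.
        vector_store.add_documents(chunks)
        return len(chunks)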