acpotts committed (verified)

Commit 31445d8 · 1 parent: ebb5ffc

Upload 2 files

Files changed (2)
  1. app.py +52 -23
  2. requirements.txt +8 -4
app.py CHANGED
@@ -14,6 +14,11 @@ import chainlit as cl
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 # from langchain_experimental.text_splitter import SemanticChunker
 # from langchain_openai.embeddings import OpenAIEmbeddings
+from sentence_transformers import SentenceTransformer
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain_core.documents import Document
 
 system_template = """\
 Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
@@ -27,27 +32,27 @@ Question:
 """
 user_role_prompt = UserRolePrompt(user_prompt_template)
 
-class RetrievalAugmentedQAPipeline:
-    def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
-        self.llm = llm
-        self.vector_db_retriever = vector_db_retriever
+# class RetrievalAugmentedQAPipeline:
+#     def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
+#         self.llm = llm
+#         self.vector_db_retriever = vector_db_retriever
 
-    async def arun_pipeline(self, user_query: str):
-        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
+#     async def arun_pipeline(self, user_query: str):
+#         context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
 
-        context_prompt = ""
-        for context in context_list:
-            context_prompt += context[0] + "\n"
+#         context_prompt = ""
+#         for context in context_list:
+#             context_prompt += context[0] + "\n"
 
-        formatted_system_prompt = system_role_prompt.create_message()
+#         formatted_system_prompt = system_role_prompt.create_message()
 
-        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
+#         formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
 
-        async def generate_response():
-            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
-                yield chunk
+#         async def generate_response():
+#             async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
+#                 yield chunk
 
-        return {"response": generate_response(), "context": context_list}
+#         return {"response": generate_response(), "context": context_list}
 
 text_splitter = RecursiveCharacterTextSplitter()
 
@@ -90,6 +95,7 @@ async def on_chat_start():
        max_files=10
    ).send()
 
+    processed_documents = []
    for file in files:
 
        msg = cl.Message(
@@ -99,26 +105,49 @@ async def on_chat_start():
 
        # load the file
        texts = process_text_file(file)
-
+        processed_documents.extend(texts)
        print(f"Processing {len(texts)} text chunks")
 
        # Create a dict vector store
-        vector_db = VectorDatabase()
-        vector_db = await vector_db.abuild_from_list(texts)
+        # vector_db = VectorDatabase()
+        # vector_db = await vector_db.abuild_from_list(texts)
 
-        chat_openai = ChatOpenAI()
+        # chat_openai = ChatOpenAI()
 
        # Create a chain
-        retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
-            vector_db_retriever=vector_db,
-            llm=chat_openai
+        # retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
+        #     vector_db_retriever=vector_db,
+        #     llm=chat_openai
+        # )
+
+
+
+        finetune_embeddings = HuggingFaceEmbeddings(model_name="finetuned_arctic")
+
+        finetune_vectorstore = FAISS.from_documents(processed_documents, finetune_embeddings)
+        finetune_retriever = finetune_vectorstore.as_retriever(search_kwargs={"k": 6})
+
+        from operator import itemgetter
+        from langchain_core.output_parsers import StrOutputParser
+        from langchain_core.runnables import RunnablePassthrough, RunnableParallel
+
+        rag_llm = ChatOpenAI(
+            model="gpt-4o-mini",
+            temperature=0
+        )
+
+        finetune_rag_chain = (
+            {"context": itemgetter("question") | finetune_retriever, "question": itemgetter("question")}
+            | RunnablePassthrough.assign(context=itemgetter("context"))
+            | {"response": system_template | rag_llm | StrOutputParser(), "context": itemgetter("context")}
        )
+
 
        # Let the user know that the system is ready
        msg.content = f"Processing `{file.name}` done. You can now ask questions!"
        await msg.update()
 
-        cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
+        cl.user_session.set("chain", finetune_rag_chain)
 
 
 @cl.on_message
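
A note on the new chain: `system_template` is a plain string, and LCEL's `|` only composes runnables, callables, and dicts, so `system_template | rag_llm | StrOutputParser()` raises a TypeError the first time `on_chat_start` builds the chain; even if it composed, the system template has no input slots, so the retrieved context would never reach the model. A minimal working variant, assuming `user_prompt_template` uses `{context}` and `{question}` placeholders (which the old pipeline's `create_message(question=..., context=...)` call suggests), wraps the existing templates in a `ChatPromptTemplate`:

    from langchain_core.prompts import ChatPromptTemplate

    # Reuses system_template / user_prompt_template already defined at the
    # top of app.py, plus finetune_retriever and rag_llm from the diff.
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", system_template),
        ("human", user_prompt_template),  # expects {context} and {question}
    ])

    finetune_rag_chain = (
        {"context": itemgetter("question") | finetune_retriever,
         "question": itemgetter("question")}
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {"response": rag_prompt | rag_llm | StrOutputParser(),
           "context": itemgetter("context")}
    )

The `@cl.on_message` handler can then run `result = await chain.ainvoke({"question": message.content})` and read the answer from `result["response"]`, or stream it with `chain.astream(...)`.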
requirements.txt CHANGED
@@ -1,7 +1,11 @@
 numpy
 chainlit==0.7.700
-openai
-langchain_community
-langchain_experimental
-langchain_openai
+# openai
+# langchain_community
+# langchain_experimental
+# langchain_openai
+# langchain_huggingface
+langchain-core==0.2.40
+langchain-openai==0.1.25
+langchain-huggingface==0.0.3
 pypdf
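
One packaging caveat: the pinned `langchain-openai==0.1.25` still installs `openai` transitively, but nothing in this list provides `langchain_community` or `faiss-cpu`, and the new `app.py` needs both for `from langchain_community.vectorstores import FAISS` and `FAISS.from_documents(...)`; if the Space build fails on those imports, re-adding them is the likely fix. A quick smoke test, assuming `finetuned_arctic` is a local sentence-transformers model directory next to `app.py` (which is how the diff references it), confirms the pinned `langchain-huggingface` release can load the fine-tuned embeddings:

    from langchain_huggingface import HuggingFaceEmbeddings

    # Same call app.py makes at chat start: loads the local fine-tuned
    # Arctic model through sentence-transformers.
    embeddings = HuggingFaceEmbeddings(model_name="finetuned_arctic")
    vector = embeddings.embed_query("smoke test")
    print(len(vector))  # embedding dimension of the fine-tuned model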