bachephysicdun committed · commit 8dbef5d · 1 parent: a605a95

implemented rag and filtered_rag

Files changed:
- app/chains.py +30 -24
- app/data_indexing.py +47 -16
- app/main.py +47 -23
- app/prompts.py +26 -11
- app/schemas.py +1 -1
app/chains.py CHANGED
@@ -10,14 +10,14 @@ from prompts import (
     raw_prompt,
     raw_prompt_formatted,
     history_prompt_formatted,
+    standalone_prompt_formatted,
+    rag_prompt_formatted,
     format_context,
     tokenizer
 )
 from data_indexing import DataIndexer
 
 
-# data_indexer = DataIndexer()
-
 llm = HuggingFaceEndpoint(
     repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
     huggingfacehub_api_token=os.environ['HF_TOKEN'],
@@ -27,31 +27,37 @@ llm = HuggingFaceEndpoint(
 )
 
 simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
-# %%
 
-
+data_indexer = DataIndexer()
 
-#
+# create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
 formatted_chain = (raw_prompt_formatted | llm).with_types(input_type=schemas.UserQuestion)
 
-#
+# use history_prompt_formatted and HistoryInput to create the history_chain
 history_chain = (history_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
 
-#
-
-
-#
-
-#
-
-
-
-
-
-
-#
-
-
-
-# filtered_rag_chain
-
+# construct the standalone_chain by piping standalone_prompt_formatted with the LLM
+standalone_chain = (standalone_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
+
+# store the result of the standalone_chain in the variable "new_question", using the variable input_1
+input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
+# store the result of the search and pull new_question into standalone_question
+input_2 = {
+    'context': lambda x: format_context(data_indexer.search(x['new_question'])),
+    'standalone_question': lambda x: x['new_question']
+}
+input_to_rag_chain = input_1 | input_2
+
+# use input_to_rag_chain, rag_prompt_formatted,
+# HistoryInput and the LLM to build the rag_chain.
+rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
+
+# Implement the filtered_rag_chain. It should be the
+# same as the rag_chain but with hybrid_search=True.
+input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
+input_2 = {
+    'context': lambda x: format_context(data_indexer.search(x['new_question'], hybrid_search=True)),
+    'standalone_question': lambda x: x['new_question']
+}
+input_to_filtered_rag_chain = input_1 | input_2
+filtered_rag_chain = (input_to_filtered_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
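For reference, a minimal local smoke test for the new chains could look like the sketch below. It assumes HF_TOKEN and PINECONE_API_KEY are set (importing chains instantiates both the HuggingFaceEndpoint and the DataIndexer) and that schemas.HistoryInput exposes question and chat_history, as app/main.py suggests; the question and history values are made up.

# hypothetical smoke test, not part of the commit
from chains import rag_chain, filtered_rag_chain

history_input = {
    # the two fields of schemas.HistoryInput, mirroring how app/main.py builds rag_input
    "question": "How do I write a custom retriever?",
    "chat_history": "User: What is LangChain?\nAI: A framework for building LLM applications.",
}

# stream tokens as the HuggingFace endpoint produces them
for chunk in rag_chain.stream(history_input):
    print(chunk, end="", flush=True)

# same pipeline, but the context comes from data_indexer.search(..., hybrid_search=True)
print(filtered_rag_chain.invoke(history_input))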
app/data_indexing.py CHANGED
@@ -6,14 +6,19 @@ from pinecone import ServerlessSpec
 from langchain_community.vectorstores import Chroma
 from langchain_openai import OpenAIEmbeddings
 
-
+from dotenv import load_dotenv
+# Specify the path to the .env file two directories up
+env_path = Path(__file__).resolve().parents[2] / '.env'
+load_dotenv(dotenv_path=env_path)
+
 
+current_dir = Path(__file__).resolve().parent
 
 class DataIndexer:
 
     source_file = os.path.join(current_dir, 'sources.txt')
 
-    def __init__(self, index_name='langchain-repo')
+    def __init__(self, index_name='langchain-repo'):
 
         # TODO: choose your embedding model
         # self.embedding_client = InferenceClient(
@@ -25,13 +30,20 @@ class DataIndexer:
         self.pinecone_client = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
 
         if index_name not in self.pinecone_client.list_indexes().names():
+
             # TODO: create your index if it doesn't exist. Use the create_index function.
             # Make sure to choose the dimension that corresponds to your embedding model
-
+            self.pinecone_client.create_index(
+                name=index_name,
+                dimension=1536,
+                metric='cosine',
+                spec=ServerlessSpec(cloud='aws', region='us-east-1')
+            )
+
 
         self.index = self.pinecone_client.Index(self.index_name)
         # TODO: make sure to build the index.
-        self.source_index =
+        self.source_index = self.get_source_index()
 
     def get_source_index(self):
         if not os.path.isfile(self.source_file):
@@ -58,8 +70,8 @@ class DataIndexer:
         for i in range(0, len(docs), batch_size):
             batch = docs[i: i + batch_size]
 
-            #
-            #
+            # create a list of the vector representations of each text data in the batch
+            # based on the selected model, choose how you extract the values
             # values = self.embedding_client.embed_documents([
             #     doc.page_content for doc in batch
             # ])
@@ -67,14 +79,19 @@ class DataIndexer:
             # values = self.embedding_client.feature_extraction([
             #     doc.page_content for doc in batch
             # ])
-
+
+            values = self.embedding_client.embed_documents([
+                doc.page_content for doc in batch
+            ])  # list of vectors -> vector representation of each doc
 
-            #
-            vector_ids =
+            # create a list of unique identifiers for each element in the batch with the uuid package.
+            vector_ids = [str(uuid.uuid4()) for _ in batch]
 
-            #
+            # create a list of dictionaries representing the metadata. Capture the text data
             # with the "text" key, and make sure to capture the rest of the doc.metadata.
-            metadatas =
+            metadatas = [{
+                'text': doc.page_content, **doc.metadata
+            } for doc in batch]
 
             # create a list of dictionaries with keys "id" (the unique identifiers), "values"
             # (the vector representation), and "metadata" (the metadata).
@@ -86,7 +103,7 @@ class DataIndexer:
 
         try:
             # TODO: Use the function upsert to upload the data to the database.
-            upsert_response =
+            upsert_response = self.index.upsert(vectors=vectors)
            print(upsert_response)
         except Exception as e:
             print(e)
@@ -104,16 +121,25 @@ class DataIndexer:
         # TODO: choose your embedding model
         # vector = self.embedding_client.feature_extraction(text_query)
         # vector = self.embedding_client.embed_query(text_query)
-        vector =
+        vector = self.embedding_client.embed_query(text_query)
 
         # TODO: use the vector representation of the text_query to
         # search the database by using the query function.
-        result =
+        result = self.index.query(
+            # namespace=self.index_name,
+            vector=vector,
+            filter=filter,
+            top_k=top_k,
+            include_metadata=True,
+        )
 
         docs = []
         for res in result["matches"]:
             # TODO: From the result's metadata, extract the "text" element.
-
+            metadata = res['metadata']
+            if 'text' in metadata:
+                text = metadata.pop('text')
+                docs.append(text)
 
         return docs
 
@@ -126,12 +152,14 @@ if __name__ == '__main__':
        RecursiveCharacterTextSplitter,
    )
 
+    print('start the GitLoader')
    loader = GitLoader(
        clone_url="https://github.com/langchain-ai/langchain",
        repo_path="./code_data/langchain_repo/",
        branch="master",
    )
 
+    print('perform python splitter')
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=10000, chunk_overlap=100
    )
@@ -143,8 +171,11 @@ if __name__ == '__main__':
    for doc in docs:
        doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content
 
+    print('instantiate the data indexer')
    indexer = DataIndexer()
-
+
+    # with open('/app/sources.txt', 'a') as file:
+    with open(indexer.source_file, 'a') as file:
        for doc in docs:
            file.writelines(doc.metadata['source'] + '\n')
    indexer.index_data(docs)
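The comments in index_data describe a vectors list of dictionaries with "id", "values" and "metadata" keys, but the line that assembles it falls outside the hunks shown above. A plausible sketch of that elided step, based only on those comments (the zip over vector_ids, values and metadatas is an assumption, not the commit's actual code):

# hypothetical reconstruction of the elided step inside index_data
vectors = [{
    'id': vector_id,       # unique identifier generated with uuid
    'values': value,       # embedding vector for the chunk
    'metadata': metadata,  # {'text': ..., **doc.metadata}
} for vector_id, value, metadata in zip(vector_ids, values, metadatas)]
# this is the list handed to self.index.upsert(vectors=vectors) in the hunk above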
app/main.py CHANGED
@@ -11,7 +11,7 @@ from typing import List
 from sqlalchemy.orm import Session
 
 import schemas
-from chains import simple_chain, formatted_chain, history_chain
+from chains import simple_chain, formatted_chain, history_chain, rag_chain, filtered_rag_chain
 import crud, models, schemas, prompts
 from database import SessionLocal, engine
 from callbacks import LogResponseCallback
@@ -114,30 +114,54 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
     ))
 
 
+@app.post("/rag/stream")
+async def rag_stream(request: Request, db: Session = Depends(get_db)):
+    # TODO: Let's implement the "/rag/stream" endpoint. The endpoint should follow these steps:
+    # - The endpoint receives the request
+    # - The request is parsed into a user request
+    # - The user request is used to pull the chat history of the user
+    # - We add as part of the user history the current question by using add_message.
+    # - We create an instance of HistoryInput by using format_chat_history.
+    # - We use the history input within the rag chain.
+
+    data = await request.json()
+    user_request = schemas.UserRequest(**data['input'])
+    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
+    message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
+    crud.add_message(db, message=message, username=user_request.username)
+    rag_input = schemas.HistoryInput(
+        question=user_request.question,
+        chat_history=prompts.format_chat_history(chat_history)
+    )
+
+    return EventSourceResponse(generate_stream(
+        rag_input, rag_chain, [LogResponseCallback(user_request, db)]
+    ))
 
 
-
-
-#
-#
-#
-#
-#
-#
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
+@app.post("/filtered_rag/stream")
+async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
+    # TODO: Let's implement the "/filtered_rag/stream" endpoint. The endpoint should follow these steps:
+    # - The endpoint receives the request
+    # - The request is parsed into a user request
+    # - The user request is used to pull the chat history of the user
+    # - We add as part of the user history the current question by using add_message.
+    # - We create an instance of HistoryInput by using format_chat_history.
+    # - We use the history input within the filtered rag chain.
+
+    data = await request.json()
+    user_request = schemas.UserRequest(**data['input'])
+    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
+    message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
+    crud.add_message(db, message=message, username=user_request.username)
+    rag_input = schemas.HistoryInput(
+        question=user_request.question,
+        chat_history=prompts.format_chat_history(chat_history)
+    )
+
+    return EventSourceResponse(generate_stream(
+        rag_input, filtered_rag_chain, [LogResponseCallback(user_request, db)]
+    ))
 
 
 # Run From the Parent Directory with Script
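Once the app is running, the new endpoints can be exercised with a small client like the sketch below. It assumes a default local FastAPI address and simply prints the server-sent events emitted by EventSourceResponse; the username and question are placeholders.

# hypothetical client for /rag/stream (swap in /filtered_rag/stream for the hybrid-search variant)
import requests

payload = {"input": {"username": "alice", "question": "How do I load a GitHub repo with GitLoader?"}}
with requests.post("http://localhost:8000/rag/stream", json=payload, stream=True) as response:
    for line in response.iter_lines():
        if line:
            print(line.decode())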
app/prompts.py CHANGED
@@ -53,7 +53,7 @@ def format_context(docs: List[str]):
     # so we need to concatenate that list into a text that can fit into
     # the rag_prompt_formatted. Implement format_context that takes a
     # list of strings and returns the context as one string.
-
+    return '\n\n'.join(docs)
 
 prompt = "{question}"
 
@@ -70,15 +70,29 @@ Follow Up Question: {question}
 helpful answer:
 """
 
-#
+# Create the standalone_prompt that will capture the question and the chat history
 # to generate a standalone question. It needs a {chat_history} placeholder and a {question} placeholder.
-standalone_prompt: str =
+standalone_prompt: str = """
+Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
 
-
+Chat History:
+{chat_history}
+
+Follow Up Input: {question}
+
+Standalone question:
+"""
+
+# Create the rag_prompt that will capture the context and the standalone question to generate
 # a final answer to the question.
-rag_prompt: str =
+rag_prompt: str = """
+Answer the question based only on the following context:
+{context}
+
+Question: {standalone_question}
+"""
 
-#
+# create raw_prompt_formatted by using format_prompt
 #raw_prompt_formatted = format_prompt(raw_prompt)
 #raw_prompt = PromptTemplate.from_template(raw_prompt)
 
@@ -89,10 +103,11 @@ raw_prompt = PromptTemplate.from_template(prompt)
 raw_prompt_formatted = format_prompt(prompt)
 
 
-#
+# use format_prompt to create history_prompt_formatted
 history_prompt_formatted = format_prompt(history_prompt)
 
-#
-standalone_prompt_formatted
-
-
+# use format_prompt to create standalone_prompt_formatted
+standalone_prompt_formatted = format_prompt(standalone_prompt)
+
+# use format_prompt to create rag_prompt_formatted
+rag_prompt_formatted = format_prompt(rag_prompt)
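To see what the new templates produce, the standalone prompt can be rendered directly with str.format (format_prompt is the app's own wrapper and is not reproduced here); the chat history and question below are invented sample values.

# render the standalone prompt template with sample values
filled = standalone_prompt.format(
    chat_history="User: What is Pinecone?\nAI: A managed vector database.",
    question="How do I create an index there?",
)
print(filled)
# rag_prompt works the same way, with {context} and {standalone_question} placeholders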
app/schemas.py CHANGED
@@ -15,7 +15,7 @@ class UserRequest(BaseModel):
     username: str
     question: str
 
-#
+# implement MessageBase as a schema mapping from the database model to the
 # FastAPI data model. Basically MessageBase should have the same attributes as models.Message
 class MessageBase(BaseModel):
     # id: int