pinecone serveless migration w langchain
Browse files
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
app.py
CHANGED
@@ -3,29 +3,29 @@ import streamlit as st
|
|
3 |
import os
|
4 |
import pkg_resources
|
5 |
|
6 |
-
# Using this wacky hack to get around the massively ridicolous managed env loading order
|
7 |
-
def is_installed(package_name, version):
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
@st.cache_resource
|
15 |
-
def install_packages():
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
|
21 |
-
|
22 |
-
|
23 |
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
# install packages if necessary
|
28 |
-
install_packages()
|
29 |
|
30 |
|
31 |
import re
|
@@ -33,58 +33,96 @@ import json
|
|
33 |
from dotenv import load_dotenv
|
34 |
import numpy as np
|
35 |
import pandas as pd
|
36 |
-
|
37 |
-
|
38 |
-
from
|
39 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
-
|
|
|
|
|
|
|
42 |
|
43 |
-
# Get openai API key
|
44 |
-
openai.api_key = os.environ["OPENAI_API_KEY"]
|
45 |
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
#
|
|
|
50 |
|
51 |
-
# @st.cache_resource
|
52 |
-
# def get_document_store():
|
53 |
-
# doc_file_name="cpv_full_southern_africa"
|
54 |
-
# document_store = PineconeDocumentStore(api_key=pinecone_key,
|
55 |
-
# environment="asia-southeast1-gcp-free",
|
56 |
-
# index=doc_file_name)
|
57 |
-
# return document_store
|
58 |
|
59 |
|
60 |
-
# # Get (or initialize and get) the document store
|
61 |
-
# document_store = get_document_store()
|
62 |
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
)
|
76 |
-
return retriever
|
77 |
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
|
|
80 |
|
81 |
-
# # Instantiate retriever
|
82 |
-
# retriever = EmbeddingRetriever(
|
83 |
-
# document_store=document_store,
|
84 |
-
# embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
|
85 |
-
# model_format="sentence_transformers",
|
86 |
-
# progress_bar=False,
|
87 |
-
# )
|
88 |
|
89 |
prompt_template="Answer the given question using the following documents. \
|
90 |
Formulate your answer in the style of an academic report. \
|
@@ -112,40 +150,6 @@ examples = [
|
|
112 |
]
|
113 |
|
114 |
|
115 |
-
def get_docs(input_query, country = [], vulnerability_cat = []):
|
116 |
-
if not country:
|
117 |
-
country = "All Countries"
|
118 |
-
if not vulnerability_cat:
|
119 |
-
if country == "All Countries":
|
120 |
-
filters = None
|
121 |
-
else:
|
122 |
-
filters = {'country': {'$in': country}}
|
123 |
-
else:
|
124 |
-
if country == "All Countries":
|
125 |
-
filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
|
126 |
-
else:
|
127 |
-
filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
|
128 |
-
docs = retriever.retrieve(query=input_query, filters = filters, top_k = 10)
|
129 |
-
# Break out the key fields and convert to pandas for filtering
|
130 |
-
docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
|
131 |
-
df_docs = pd.DataFrame(docs)
|
132 |
-
# Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
|
133 |
-
df_docs = df_docs.reset_index()
|
134 |
-
df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
|
135 |
-
# Convert back to Document format
|
136 |
-
ls_dict = []
|
137 |
-
# Iterate over df and add relevant fields to the dict object
|
138 |
-
for index, row in df_docs.iterrows():
|
139 |
-
# Create a Document object for each row
|
140 |
-
doc = Document(
|
141 |
-
row['content'],
|
142 |
-
meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'vulnerability_cat': row['vulnerability_cat'], 'score': row['score']}
|
143 |
-
)
|
144 |
-
|
145 |
-
# Append the Document object to the documents list
|
146 |
-
ls_dict.append(doc)
|
147 |
-
return ls_dict
|
148 |
-
|
149 |
def get_refs(docs, res):
|
150 |
'''
|
151 |
Parse response for engineered reference ids (refer to prompt template)
|
@@ -159,40 +163,44 @@ def get_refs(docs, res):
|
|
159 |
# extract
|
160 |
result_str = "" # Initialize an empty string to store the result
|
161 |
for i in range(len(docs)):
|
162 |
-
|
163 |
-
ref_id = doc['meta']['ref_id']
|
164 |
if ref_id in ref_ids:
|
165 |
-
if
|
166 |
-
result_str += "**Ref. " + str(ref_id) + " [" +
|
167 |
else:
|
168 |
-
result_str += "**Ref. " + str(ref_id) + " [" +
|
169 |
|
170 |
return result_str
|
171 |
|
172 |
# define a special function for putting the prompt together (as we can't use haystack)
|
173 |
def get_prompt(docs, input_query):
|
174 |
base_prompt=prompt_template
|
175 |
-
# Add the
|
176 |
-
context = ' - '.join(['&&& [ref. '+str(d.
|
177 |
prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
|
178 |
return(prompt)
|
179 |
|
180 |
-
def run_query(
|
181 |
# first call the retriever function using selected filters
|
182 |
-
docs = get_docs(
|
183 |
# model selector (not currently being used)
|
184 |
if model_sel == "chatGPT":
|
185 |
# instantiate ChatCompletion as a generator object (stream is set to True)
|
186 |
-
response = openai.ChatCompletion.create(model="gpt-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
# iterate through the streamed output
|
188 |
report = []
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
report.append(chunk_message.content) # extract the message
|
195 |
-
# add the latest text and merge it with all previous
|
196 |
result = "".join(report).strip()
|
197 |
res_box.success(result) # output to response text box
|
198 |
|
|
|
3 |
import os
|
4 |
import pkg_resources
|
5 |
|
6 |
+
# # Using this wacky hack to get around the massively ridicolous managed env loading order
|
7 |
+
# def is_installed(package_name, version):
|
8 |
+
# try:
|
9 |
+
# pkg = pkg_resources.get_distribution(package_name)
|
10 |
+
# return pkg.version == version
|
11 |
+
# except pkg_resources.DistributionNotFound:
|
12 |
+
# return False
|
13 |
+
|
14 |
+
# @st.cache_resource
|
15 |
+
# def install_packages():
|
16 |
+
# install_commands = []
|
17 |
+
|
18 |
+
# if not is_installed("spaces", "0.12.0"):
|
19 |
+
# install_commands.append("pip install spaces==0.12.0")
|
20 |
|
21 |
+
# if not is_installed("pydantic", "1.8.2"):
|
22 |
+
# install_commands.append("pip install pydantic==1.8.2")
|
23 |
|
24 |
+
# if install_commands:
|
25 |
+
# os.system(" && ".join(install_commands))
|
26 |
|
27 |
+
# # install packages if necessary
|
28 |
+
# # install_packages()
|
29 |
|
30 |
|
31 |
import re
|
|
|
33 |
from dotenv import load_dotenv
|
34 |
import numpy as np
|
35 |
import pandas as pd
|
36 |
+
import getpass
|
37 |
+
import os
|
38 |
+
from dotenv import load_dotenv, find_dotenv
|
39 |
+
from pinecone import Pinecone, ServerlessSpec
|
40 |
+
from langchain_pinecone import PineconeVectorStore
|
41 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
42 |
+
# from langchain_core.output_parsers import StrOutputParser
|
43 |
+
# from langchain_core.runnables import RunnablePassthrough
|
44 |
+
# from langchain_openai import ChatOpenAI
|
45 |
+
from langchain.docstore.document import Document
|
46 |
+
from openai import OpenAI
|
47 |
+
|
48 |
+
client = OpenAI(
|
49 |
+
organization='org-x0YBcOjkdPyf6ExxWCkmFHAj',
|
50 |
+
project='proj_40oH22n9XudeKL2rgka1IQ5B',
|
51 |
+
api_key='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'
|
52 |
+
)
|
53 |
+
|
54 |
+
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
|
55 |
|
56 |
+
@st.cache_resource
|
57 |
+
def initialize_embeddings(model_name: str = "all-mpnet-base-v2"):
|
58 |
+
embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
59 |
+
return embeddings
|
60 |
|
|
|
|
|
61 |
|
62 |
+
@st.cache_resource
|
63 |
+
def initialize_vector_store(pinecone_api_key: str, index_name: str):
|
64 |
+
# Initialize Pinecone
|
65 |
+
pc = Pinecone(api_key=pinecone_api_key)
|
66 |
+
|
67 |
+
# Access the index
|
68 |
+
index = pc.Index(index_name)
|
69 |
+
|
70 |
+
# Use the cached embeddings
|
71 |
+
embeddings = initialize_embeddings()
|
72 |
+
|
73 |
+
# Create the vector store
|
74 |
+
vector_store = PineconeVectorStore(index=index, embedding=embeddings, text_key='content')
|
75 |
+
|
76 |
+
return vector_store, embeddings
|
77 |
|
78 |
+
# Unpack the tuple into both vector_store and embeddings
|
79 |
+
vector_store, embeddings = initialize_vector_store(pinecone_api_key, index_name="cpv-full-southern-africa-test")
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
|
|
|
|
|
83 |
|
84 |
+
def get_docs(query, country = [], vulnerability_cat = []):
|
85 |
|
86 |
+
if not country:
|
87 |
+
country = "All Countries"
|
88 |
+
if not vulnerability_cat:
|
89 |
+
if country == "All Countries":
|
90 |
+
filters = None
|
91 |
+
else:
|
92 |
+
filters = {'country': {'$in': country}}
|
93 |
+
else:
|
94 |
+
if country == "All Countries":
|
95 |
+
filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
|
96 |
+
else:
|
97 |
+
filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
|
98 |
+
|
99 |
+
|
100 |
+
docs = vector_store.similarity_search_by_vector_with_score(
|
101 |
+
embeddings.embed_query(query),
|
102 |
+
k=20,
|
103 |
+
filter=filters,
|
104 |
)
|
|
|
105 |
|
106 |
+
# Break out the key fields and convert to pandas for filtering
|
107 |
+
docs_dict = [{**x[0].metadata,"score":x[1],"content":x[0].page_content} for x in docs]
|
108 |
+
df_docs = pd.DataFrame(docs_dict)
|
109 |
+
# Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
|
110 |
+
df_docs = df_docs.reset_index()
|
111 |
+
df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
|
112 |
+
# Convert back to Document format
|
113 |
+
ls_dict = []
|
114 |
+
# Iterate over df and add relevant fields to the dict object
|
115 |
+
for index, row in df_docs.iterrows():
|
116 |
+
# Create a Document object for each row
|
117 |
+
doc = Document(
|
118 |
+
page_content = row['content'],
|
119 |
+
metadata={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'vulnerability_cat': row['vulnerability_cat'], 'score': row['score']}
|
120 |
+
)
|
121 |
+
# Append the Document object to the documents list
|
122 |
+
ls_dict.append(doc)
|
123 |
|
124 |
+
return ls_dict
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
prompt_template="Answer the given question using the following documents. \
|
128 |
Formulate your answer in the style of an academic report. \
|
|
|
150 |
]
|
151 |
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
def get_refs(docs, res):
|
154 |
'''
|
155 |
Parse response for engineered reference ids (refer to prompt template)
|
|
|
163 |
# extract
|
164 |
result_str = "" # Initialize an empty string to store the result
|
165 |
for i in range(len(docs)):
|
166 |
+
ref_id = docs[i].metadata['ref_id']
|
|
|
167 |
if ref_id in ref_ids:
|
168 |
+
if docs[i].metadata['document'] == "Supplementary":
|
169 |
+
result_str += "**Ref. " + str(ref_id) + " [" + docs[i].metadata['country'] + " " + docs[i].metadata['document'] + ':' + docs[i].metadata['file_name'] + ' p' + str(docs[i].metadata['page']) + '; vulnerabilities: ' + docs[i].metadata['vulnerability_cat'] + "]:** " + "*'" + docs[i].page_content + "'*<br> <br>" # Add <br> for a line break
|
170 |
else:
|
171 |
+
result_str += "**Ref. " + str(ref_id) + " [" + docs[i].metadata['country'] + " " + docs[i].metadata['document'] + ' p' + str(docs[i].metadata['page']) + '; vulnerabilities: ' + docs[i].metadata['vulnerability_cat'] + "]:** " + "*'" + docs[i].page_content + "'*<br> <br>" # Add <br> for a line break
|
172 |
|
173 |
return result_str
|
174 |
|
175 |
# define a special function for putting the prompt together (as we can't use haystack)
|
176 |
def get_prompt(docs, input_query):
|
177 |
base_prompt=prompt_template
|
178 |
+
# Add the metadata data for references
|
179 |
+
context = ' - '.join(['&&& [ref. '+str(d.metadata['ref_id'])+'] '+d.metadata['document']+' &&&: '+d.page_content for d in docs])
|
180 |
prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
|
181 |
return(prompt)
|
182 |
|
183 |
+
def run_query(query, country, model_sel):
|
184 |
# first call the retriever function using selected filters
|
185 |
+
docs = get_docs(query, country=country,vulnerability_cat=vulnerabilities_cat)
|
186 |
# model selector (not currently being used)
|
187 |
if model_sel == "chatGPT":
|
188 |
# instantiate ChatCompletion as a generator object (stream is set to True)
|
189 |
+
# response = openai.ChatCompletion.create(model="gpt-4o-mini-2024-07-18", messages=[{"role": "user", "content": get_prompt(docs, query)}], stream=True)
|
190 |
+
|
191 |
+
|
192 |
+
stream = client.chat.completions.create(
|
193 |
+
model="gpt-4o-mini-2024-07-18",
|
194 |
+
messages=[{"role": "user", "content": get_prompt(docs, query)}],
|
195 |
+
stream=True,
|
196 |
+
)
|
197 |
# iterate through the streamed output
|
198 |
report = []
|
199 |
+
|
200 |
+
for chunk in stream:
|
201 |
+
if chunk.choices[0].delta.content is not None:
|
202 |
+
# print(chunk.choices[0].delta.content, end="")
|
203 |
+
report.append(chunk.choices[0].delta.content)
|
|
|
|
|
204 |
result = "".join(report).strip()
|
205 |
res_box.success(result) # output to response text box
|
206 |
|
env
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OPENAI_API_KEY="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"
|
2 |
+
HF_API_KEY="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"
|
3 |
+
PINECONE_API_KEY="c3f5717c-f43a-46d0-893e-02b44dbcf13b"
|
4 |
+
USER1_HASH="$2b$12$hZbOi6zKmQQWvvpcllds9uAB3ili66N0aQyPzuDctl7IkNhl226oG"
|
5 |
+
USER2_HASH="$2b$12$kWnArbA.2QTkpMv2yvE2J.7UJw0Fgc/3FH1k5JRqhjg.cvytriGt2"
|