mtyrrell committed on
Commit
cd9150d
·
1 Parent(s): ab45f35

pinecone serverless migration w/ langchain

Browse files
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. app.py +118 -110
  3. env +5 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -3,29 +3,29 @@ import streamlit as st
3
  import os
4
  import pkg_resources
5
 
6
- # Using this wacky hack to get around the massively ridiculous managed env loading order
7
- def is_installed(package_name, version):
8
- try:
9
- pkg = pkg_resources.get_distribution(package_name)
10
- return pkg.version == version
11
- except pkg_resources.DistributionNotFound:
12
- return False
13
-
14
- @st.cache_resource
15
- def install_packages():
16
- install_commands = []
17
-
18
- if not is_installed("spaces", "0.12.0"):
19
- install_commands.append("pip install spaces==0.12.0")
20
 
21
- if not is_installed("pydantic", "1.8.2"):
22
- install_commands.append("pip install pydantic==1.8.2")
23
 
24
- if install_commands:
25
- os.system(" && ".join(install_commands))
26
 
27
- # install packages if necessary
28
- install_packages()
29
 
30
 
31
  import re
@@ -33,58 +33,96 @@ import json
33
  from dotenv import load_dotenv
34
  import numpy as np
35
  import pandas as pd
36
- from haystack.schema import Document
37
- from haystack.document_stores import PineconeDocumentStore
38
- from haystack.nodes import EmbeddingRetriever
39
- import openai
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # for local st testing, may need to run source ~/.zshrc to point to env vars
 
 
 
42
 
43
- # Get openai API key
44
- openai.api_key = os.environ["OPENAI_API_KEY"]
45
 
46
- # Get Pinecone API key
47
- pinecone_key = os.environ["PINECONE_API_KEY"]
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- #___________________________________________________________________________________________________________
 
50
 
51
- # @st.cache_resource
52
- # def get_document_store():
53
- # doc_file_name="cpv_full_southern_africa"
54
- # document_store = PineconeDocumentStore(api_key=pinecone_key,
55
- # environment="asia-southeast1-gcp-free",
56
- # index=doc_file_name)
57
- # return document_store
58
 
59
 
60
- # # Get (or initialize and get) the document store
61
- # document_store = get_document_store()
62
 
 
63
 
64
- @st.cache_resource
65
- def get_retriever():
66
- doc_file_name="cpv_full_southern_africa"
67
- document_store = PineconeDocumentStore(api_key=pinecone_key,
68
- environment="asia-southeast1-gcp-free",
69
- index=doc_file_name)
70
- retriever = EmbeddingRetriever(
71
- document_store=document_store,
72
- embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
73
- model_format="sentence_transformers",
74
- progress_bar=False,
 
 
 
 
 
 
 
75
  )
76
- return retriever
77
 
78
- retriever = get_retriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
 
80
 
81
- # # Instantiate retriever
82
- # retriever = EmbeddingRetriever(
83
- # document_store=document_store,
84
- # embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
85
- # model_format="sentence_transformers",
86
- # progress_bar=False,
87
- # )
88
 
89
  prompt_template="Answer the given question using the following documents. \
90
  Formulate your answer in the style of an academic report. \
@@ -112,40 +150,6 @@ examples = [
112
  ]
113
 
114
 
115
- def get_docs(input_query, country = [], vulnerability_cat = []):
116
- if not country:
117
- country = "All Countries"
118
- if not vulnerability_cat:
119
- if country == "All Countries":
120
- filters = None
121
- else:
122
- filters = {'country': {'$in': country}}
123
- else:
124
- if country == "All Countries":
125
- filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
126
- else:
127
- filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
128
- docs = retriever.retrieve(query=input_query, filters = filters, top_k = 10)
129
- # Break out the key fields and convert to pandas for filtering
130
- docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
131
- df_docs = pd.DataFrame(docs)
132
- # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
133
- df_docs = df_docs.reset_index()
134
- df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
135
- # Convert back to Document format
136
- ls_dict = []
137
- # Iterate over df and add relevant fields to the dict object
138
- for index, row in df_docs.iterrows():
139
- # Create a Document object for each row
140
- doc = Document(
141
- row['content'],
142
- meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'vulnerability_cat': row['vulnerability_cat'], 'score': row['score']}
143
- )
144
-
145
- # Append the Document object to the documents list
146
- ls_dict.append(doc)
147
- return ls_dict
148
-
149
  def get_refs(docs, res):
150
  '''
151
  Parse response for engineered reference ids (refer to prompt template)
@@ -159,40 +163,44 @@ def get_refs(docs, res):
159
  # extract
160
  result_str = "" # Initialize an empty string to store the result
161
  for i in range(len(docs)):
162
- doc = docs[i].to_dict()
163
- ref_id = doc['meta']['ref_id']
164
  if ref_id in ref_ids:
165
- if doc['meta']['document'] == "Supplementary":
166
- result_str += "**Ref. " + str(ref_id) + " [" + doc['meta']['country'] + " " + doc['meta']['document'] + ':' + doc['meta']['file_name'] + ' p' + str(doc['meta']['page']) + '; vulnerabilities: ' + doc['meta']['vulnerability_cat'] + "]:** " + "*'" + doc['content'] + "'*<br> <br>" # Add <br> for a line break
167
  else:
168
- result_str += "**Ref. " + str(ref_id) + " [" + doc['meta']['country'] + " " + doc['meta']['document'] + ' p' + str(doc['meta']['page']) + '; vulnerabilities: ' + doc['meta']['vulnerability_cat'] + "]:** " + "*'" + doc['content'] + "'*<br> <br>" # Add <br> for a line break
169
 
170
  return result_str
171
 
172
  # define a special function for putting the prompt together (as we can't use haystack)
173
  def get_prompt(docs, input_query):
174
  base_prompt=prompt_template
175
- # Add the meta data for references
176
- context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
177
  prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
178
  return(prompt)
179
 
180
- def run_query(input_text, country, model_sel):
181
  # first call the retriever function using selected filters
182
- docs = get_docs(input_text, country=country,vulnerability_cat=vulnerabilities_cat)
183
  # model selector (not currently being used)
184
  if model_sel == "chatGPT":
185
  # instantiate ChatCompletion as a generator object (stream is set to True)
186
- response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_text)}], stream=True)
 
 
 
 
 
 
 
187
  # iterate through the streamed output
188
  report = []
189
- for chunk in response:
190
- # extract the object containing the text (totally different structure when streaming)
191
- chunk_message = chunk['choices'][0]['delta']
192
- # test to make sure there is text in the object (some don't have)
193
- if 'content' in chunk_message:
194
- report.append(chunk_message.content) # extract the message
195
- # add the latest text and merge it with all previous
196
  result = "".join(report).strip()
197
  res_box.success(result) # output to response text box
198
 
 
3
  import os
4
  import pkg_resources
5
 
6
+ # # Using this wacky hack to get around the massively ridiculous managed env loading order
7
+ # def is_installed(package_name, version):
8
+ # try:
9
+ # pkg = pkg_resources.get_distribution(package_name)
10
+ # return pkg.version == version
11
+ # except pkg_resources.DistributionNotFound:
12
+ # return False
13
+
14
+ # @st.cache_resource
15
+ # def install_packages():
16
+ # install_commands = []
17
+
18
+ # if not is_installed("spaces", "0.12.0"):
19
+ # install_commands.append("pip install spaces==0.12.0")
20
 
21
+ # if not is_installed("pydantic", "1.8.2"):
22
+ # install_commands.append("pip install pydantic==1.8.2")
23
 
24
+ # if install_commands:
25
+ # os.system(" && ".join(install_commands))
26
 
27
+ # # install packages if necessary
28
+ # # install_packages()
29
 
30
 
31
  import re
 
33
  from dotenv import load_dotenv
34
  import numpy as np
35
  import pandas as pd
36
+ import getpass
37
+ import os
38
+ from dotenv import load_dotenv, find_dotenv
39
+ from pinecone import Pinecone, ServerlessSpec
40
+ from langchain_pinecone import PineconeVectorStore
41
+ from langchain_huggingface import HuggingFaceEmbeddings
42
+ # from langchain_core.output_parsers import StrOutputParser
43
+ # from langchain_core.runnables import RunnablePassthrough
44
+ # from langchain_openai import ChatOpenAI
45
+ from langchain.docstore.document import Document
46
+ from openai import OpenAI
47
+
48
+ client = OpenAI(
49
+ organization='org-x0YBcOjkdPyf6ExxWCkmFHAj',
50
+ project='proj_40oH22n9XudeKL2rgka1IQ5B',
51
+ api_key='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'
52
+ )
53
+
54
+ pinecone_api_key = os.environ.get("PINECONE_API_KEY")
55
 
56
+ @st.cache_resource
57
+ def initialize_embeddings(model_name: str = "all-mpnet-base-v2"):
58
+ embeddings = HuggingFaceEmbeddings(model_name=model_name)
59
+ return embeddings
60
 
 
 
61
 
62
+ @st.cache_resource
63
+ def initialize_vector_store(pinecone_api_key: str, index_name: str):
64
+ # Initialize Pinecone
65
+ pc = Pinecone(api_key=pinecone_api_key)
66
+
67
+ # Access the index
68
+ index = pc.Index(index_name)
69
+
70
+ # Use the cached embeddings
71
+ embeddings = initialize_embeddings()
72
+
73
+ # Create the vector store
74
+ vector_store = PineconeVectorStore(index=index, embedding=embeddings, text_key='content')
75
+
76
+ return vector_store, embeddings
77
 
78
+ # Unpack the tuple into both vector_store and embeddings
79
+ vector_store, embeddings = initialize_vector_store(pinecone_api_key, index_name="cpv-full-southern-africa-test")
80
 
 
 
 
 
 
 
 
81
 
82
 
 
 
83
 
84
+ def get_docs(query, country = [], vulnerability_cat = []):
85
 
86
+ if not country:
87
+ country = "All Countries"
88
+ if not vulnerability_cat:
89
+ if country == "All Countries":
90
+ filters = None
91
+ else:
92
+ filters = {'country': {'$in': country}}
93
+ else:
94
+ if country == "All Countries":
95
+ filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
96
+ else:
97
+ filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
98
+
99
+
100
+ docs = vector_store.similarity_search_by_vector_with_score(
101
+ embeddings.embed_query(query),
102
+ k=20,
103
+ filter=filters,
104
  )
 
105
 
106
+ # Break out the key fields and convert to pandas for filtering
107
+ docs_dict = [{**x[0].metadata,"score":x[1],"content":x[0].page_content} for x in docs]
108
+ df_docs = pd.DataFrame(docs_dict)
109
+ # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
110
+ df_docs = df_docs.reset_index()
111
+ df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
112
+ # Convert back to Document format
113
+ ls_dict = []
114
+ # Iterate over df and add relevant fields to the dict object
115
+ for index, row in df_docs.iterrows():
116
+ # Create a Document object for each row
117
+ doc = Document(
118
+ page_content = row['content'],
119
+ metadata={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'vulnerability_cat': row['vulnerability_cat'], 'score': row['score']}
120
+ )
121
+ # Append the Document object to the documents list
122
+ ls_dict.append(doc)
123
 
124
+ return ls_dict
125
 
 
 
 
 
 
 
 
126
 
127
  prompt_template="Answer the given question using the following documents. \
128
  Formulate your answer in the style of an academic report. \
 
150
  ]
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def get_refs(docs, res):
154
  '''
155
  Parse response for engineered reference ids (refer to prompt template)
 
163
  # extract
164
  result_str = "" # Initialize an empty string to store the result
165
  for i in range(len(docs)):
166
+ ref_id = docs[i].metadata['ref_id']
 
167
  if ref_id in ref_ids:
168
+ if docs[i].metadata['document'] == "Supplementary":
169
+ result_str += "**Ref. " + str(ref_id) + " [" + docs[i].metadata['country'] + " " + docs[i].metadata['document'] + ':' + docs[i].metadata['file_name'] + ' p' + str(docs[i].metadata['page']) + '; vulnerabilities: ' + docs[i].metadata['vulnerability_cat'] + "]:** " + "*'" + docs[i].page_content + "'*<br> <br>" # Add <br> for a line break
170
  else:
171
+ result_str += "**Ref. " + str(ref_id) + " [" + docs[i].metadata['country'] + " " + docs[i].metadata['document'] + ' p' + str(docs[i].metadata['page']) + '; vulnerabilities: ' + docs[i].metadata['vulnerability_cat'] + "]:** " + "*'" + docs[i].page_content + "'*<br> <br>" # Add <br> for a line break
172
 
173
  return result_str
174
 
175
  # define a special function for putting the prompt together (as we can't use haystack)
176
  def get_prompt(docs, input_query):
177
  base_prompt=prompt_template
178
+ # Add the metadata for references
179
+ context = ' - '.join(['&&& [ref. '+str(d.metadata['ref_id'])+'] '+d.metadata['document']+' &&&: '+d.page_content for d in docs])
180
  prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
181
  return(prompt)
182
 
183
+ def run_query(query, country, model_sel):
184
  # first call the retriever function using selected filters
185
+ docs = get_docs(query, country=country,vulnerability_cat=vulnerabilities_cat)
186
  # model selector (not currently being used)
187
  if model_sel == "chatGPT":
188
  # instantiate ChatCompletion as a generator object (stream is set to True)
189
+ # response = openai.ChatCompletion.create(model="gpt-4o-mini-2024-07-18", messages=[{"role": "user", "content": get_prompt(docs, query)}], stream=True)
190
+
191
+
192
+ stream = client.chat.completions.create(
193
+ model="gpt-4o-mini-2024-07-18",
194
+ messages=[{"role": "user", "content": get_prompt(docs, query)}],
195
+ stream=True,
196
+ )
197
  # iterate through the streamed output
198
  report = []
199
+
200
+ for chunk in stream:
201
+ if chunk.choices[0].delta.content is not None:
202
+ # print(chunk.choices[0].delta.content, end="")
203
+ report.append(chunk.choices[0].delta.content)
 
 
204
  result = "".join(report).strip()
205
  res_box.success(result) # output to response text box
206
 
env ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ OPENAI_API_KEY="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"
2
+ HF_API_KEY="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"
3
+ PINECONE_API_KEY="c3f5717c-f43a-46d0-893e-02b44dbcf13b"
4
+ USER1_HASH="$2b$12$hZbOi6zKmQQWvvpcllds9uAB3ili66N0aQyPzuDctl7IkNhl226oG"
5
+ USER2_HASH="$2b$12$kWnArbA.2QTkpMv2yvE2J.7UJw0Fgc/3FH1k5JRqhjg.cvytriGt2"