samlonka committed on
Commit
1ec4e65
·
1 Parent(s): 5862201

'tool_added'

Browse files
Files changed (1) hide show
  1. vector_tool.py +93 -0
vector_tool.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import pickle
4
+ import streamlit as st
5
+ from dotenv import load_dotenv
6
+ from pinecone import Pinecone, ServerlessSpec
7
+ from utils import load_pickle, initialize_embedding_model
8
+ from langchain_community.retrievers import BM25Retriever
9
+ from langchain_pinecone import PineconeVectorStore
10
+ from langchain.retrievers import EnsembleRetriever
11
+ from langchain.tools.retriever import create_retriever_tool
12
+
13
+
14
# --- Environment and Pinecone client setup -------------------------------

# Load environment variables from a local .env file (no-op if the file
# is absent); supplies PINECONE_API_KEY_SAM in local development.
load_dotenv()

# Constants
INDEX_NAME = "veda-index-v2"            # Pinecone serverless index name
MODEL_NAME = "BAAI/bge-large-en-v1.5"   # embedding model (1024-dim output)
# Path to the pickled, pre-chunked document records consumed below.
# NOTE: built with os.path.join for portability — the original used the
# Windows-only literal r"Docs\ramana_docs_ids.pkl", which fails on the
# Linux hosts this Streamlit app deploys to.
DOCS_DIRECTORY = os.path.join("Docs", "ramana_docs_ids.pkl")

# Initialize Pinecone client (key may be None if the env var is unset;
# the Pinecone constructor/first API call will surface that as an error).
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY_SAM")
pc = Pinecone(api_key=PINECONE_API_KEY)
28
+
#@st.cache_resource
def create_or_load_index(timeout=None):
    """Return a handle to the Pinecone index, creating it first if needed.

    If ``INDEX_NAME`` does not exist yet, a 1024-dim dotproduct serverless
    index is created on AWS us-east-1 and this function polls once per
    second until Pinecone reports it ready.

    Args:
        timeout: Optional maximum number of seconds to wait for a newly
            created index to become ready. ``None`` (the default) waits
            indefinitely, matching the original behavior.

    Returns:
        The ``pinecone.Index`` connection for ``INDEX_NAME``.

    Raises:
        TimeoutError: If ``timeout`` elapses before the index is ready.
    """
    # Check if index already exists
    if INDEX_NAME not in pc.list_indexes().names():
        # BGE-large-en-v1.5 embeds to 1024 dims; dotproduct metric matches
        # the (normalized) embeddings used by the dense retriever below.
        pc.create_index(
            INDEX_NAME,
            dimension=1024,
            metric='dotproduct',
            spec=ServerlessSpec(
                cloud="aws",
                region="us-east-1"
            )
        )
        # Poll until the serverless index is initialized; bounded by the
        # optional deadline so a stuck provision cannot hang the app forever.
        deadline = None if timeout is None else time.monotonic() + timeout
        while not pc.describe_index(INDEX_NAME).status['ready']:
            if deadline is not None and time.monotonic() > deadline:
                raise TimeoutError(
                    f"Pinecone index '{INDEX_NAME}' not ready after {timeout}s"
                )
            time.sleep(1)
    # Connect to index
    return pc.Index(INDEX_NAME)
48
+
49
+ # Load documents
50
+ docs = load_pickle(DOCS_DIRECTORY)
51
+ # Initialize embedding model
52
+ embedding = initialize_embedding_model(MODEL_NAME)
53
+ # Create or load index
54
+ index = create_or_load_index()
55
+
56
+ # Initialize BM25 retriever
57
+ bm25_retriever = BM25Retriever.from_texts(
58
+ [text['document'].page_content for text in docs],
59
+ metadatas=[text['document'].metadata for text in docs]
60
+ )
61
+ bm25_retriever.k = 2
62
+
63
+ # Switch back to normal index for LangChain
64
+ vector_store = PineconeVectorStore(index, embedding)
65
+ retriever = vector_store.as_retriever(search_type="mmr")
66
+
67
+ # Initialize the ensemble retriever
68
+ ensemble_retriever = EnsembleRetriever(
69
+ retrievers=[bm25_retriever, retriever], weights=[0.2, 0.8]
70
+ )
71
+
72
+
73
+ vector_tools = create_retriever_tool(
74
+ retriever = ensemble_retriever,
75
+ name = "vector_retrieve",
76
+ description="Search and return documents related user query from the vector index.",
77
+ )
78
+
79
+ from langchain import hub
80
+
81
+ prompt = hub.pull("hwchase17/openai-tools-agent")
82
+ prompt.messages
83
+
84
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
85
+ from langchain_openai import ChatOpenAI
86
+ import streamlit as st
87
+
88
+ os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
89
+ #load llm model
90
+ llm_AI4 = ChatOpenAI(model="gpt-4-1106-preview", temperature=0)
91
+
92
+ agent = create_openai_tools_agent(llm_AI4, [vector_tools], prompt)
93
+ agent_executor = AgentExecutor(agent=agent, tools=[vector_tools])