Spaces:
Sleeping
Sleeping
samlonka
commited on
Commit
·
1ec4e65
1
Parent(s):
5862201
'tool_added'
Browse files- vector_tool.py +93 -0
vector_tool.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import pickle
|
4 |
+
import streamlit as st
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from pinecone import Pinecone, ServerlessSpec
|
7 |
+
from utils import load_pickle, initialize_embedding_model
|
8 |
+
from langchain_community.retrievers import BM25Retriever
|
9 |
+
from langchain_pinecone import PineconeVectorStore
|
10 |
+
from langchain.retrievers import EnsembleRetriever
|
11 |
+
from langchain.tools.retriever import create_retriever_tool
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Load .env file
|
16 |
+
load_dotenv()
|
17 |
+
|
18 |
+
# Constants
|
19 |
+
INDEX_NAME = "veda-index-v2"
|
20 |
+
MODEL_NAME = "BAAI/bge-large-en-v1.5"
|
21 |
+
DOCS_DIRECTORY = r"Docs\ramana_docs_ids.pkl"
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
# Initialize Pinecone client
|
26 |
+
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY_SAM")
|
27 |
+
pc = Pinecone(api_key=PINECONE_API_KEY)
|
28 |
+
|
29 |
+
#@st.cache_resource
|
30 |
+
def create_or_load_index():
|
31 |
+
# Check if index already exists
|
32 |
+
if INDEX_NAME not in pc.list_indexes().names():
|
33 |
+
# Create index if it does not exist
|
34 |
+
pc.create_index(
|
35 |
+
INDEX_NAME,
|
36 |
+
dimension=1024,
|
37 |
+
metric='dotproduct',
|
38 |
+
spec=ServerlessSpec(
|
39 |
+
cloud="aws",
|
40 |
+
region="us-east-1"
|
41 |
+
)
|
42 |
+
)
|
43 |
+
# Wait for index to be initialized
|
44 |
+
while not pc.describe_index(INDEX_NAME).status['ready']:
|
45 |
+
time.sleep(1)
|
46 |
+
# Connect to index
|
47 |
+
return pc.Index(INDEX_NAME)
|
48 |
+
|
49 |
+
# Load documents
|
50 |
+
docs = load_pickle(DOCS_DIRECTORY)
|
51 |
+
# Initialize embedding model
|
52 |
+
embedding = initialize_embedding_model(MODEL_NAME)
|
53 |
+
# Create or load index
|
54 |
+
index = create_or_load_index()
|
55 |
+
|
56 |
+
# Initialize BM25 retriever
|
57 |
+
bm25_retriever = BM25Retriever.from_texts(
|
58 |
+
[text['document'].page_content for text in docs],
|
59 |
+
metadatas=[text['document'].metadata for text in docs]
|
60 |
+
)
|
61 |
+
bm25_retriever.k = 2
|
62 |
+
|
63 |
+
# Switch back to normal index for LangChain
|
64 |
+
vector_store = PineconeVectorStore(index, embedding)
|
65 |
+
retriever = vector_store.as_retriever(search_type="mmr")
|
66 |
+
|
67 |
+
# Initialize the ensemble retriever
|
68 |
+
ensemble_retriever = EnsembleRetriever(
|
69 |
+
retrievers=[bm25_retriever, retriever], weights=[0.2, 0.8]
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
vector_tools = create_retriever_tool(
|
74 |
+
retriever = ensemble_retriever,
|
75 |
+
name = "vector_retrieve",
|
76 |
+
description="Search and return documents related user query from the vector index.",
|
77 |
+
)
|
78 |
+
|
79 |
+
from langchain import hub
|
80 |
+
|
81 |
+
prompt = hub.pull("hwchase17/openai-tools-agent")
|
82 |
+
prompt.messages
|
83 |
+
|
84 |
+
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
85 |
+
from langchain_openai import ChatOpenAI
|
86 |
+
import streamlit as st
|
87 |
+
|
88 |
+
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
89 |
+
#load llm model
|
90 |
+
llm_AI4 = ChatOpenAI(model="gpt-4-1106-preview", temperature=0)
|
91 |
+
|
92 |
+
agent = create_openai_tools_agent(llm_AI4, [vector_tools], prompt)
|
93 |
+
agent_executor = AgentExecutor(agent=agent, tools=[vector_tools])
|