jseims committed on
Commit
1e690b2
1 Parent(s): b9a7468

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -47
app.py CHANGED
@@ -1,64 +1,194 @@
1
- import os
2
- import openai
3
-
4
- from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
5
- from llama_index.callbacks.base import CallbackManager
6
- from llama_index import (
7
- LLMPredictor,
8
- ServiceContext,
9
- StorageContext,
10
- load_index_from_storage,
11
- )
12
- from langchain.chat_models import ChatOpenAI
13
  import chainlit as cl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- try:
16
- # rebuild storage context
17
- storage_context = StorageContext.from_defaults(persist_dir="./storage")
18
- # load index
19
- index = load_index_from_storage(storage_context)
20
- except:
21
- from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader
22
 
23
- documents = SimpleDirectoryReader("./data").load_data()
24
- index = GPTVectorStoreIndex.from_documents(documents)
25
- index.storage_context.persist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
27
 
28
  @cl.on_chat_start
29
- async def factory():
30
- llm_predictor = LLMPredictor(
31
- llm=ChatOpenAI(
32
- temperature=0,
33
- model_name="gpt-3.5-turbo",
34
- streaming=True,
35
- ),
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
- service_context = ServiceContext.from_defaults(
38
- llm_predictor=llm_predictor,
39
- chunk_size=512,
40
- callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
 
 
 
 
 
 
 
 
41
  )
42
 
43
- query_engine = index.as_query_engine(
44
- service_context=service_context,
45
- streaming=True,
 
 
 
 
 
 
 
 
 
46
  )
47
 
48
- cl.user_session.set("query_engine", query_engine)
 
 
 
49
 
 
 
 
 
50
 
51
- @cl.on_message
52
- async def main(message):
53
- query_engine = cl.user_session.get("query_engine") # type: RetrieverQueryEngine
54
- response = await cl.make_async(query_engine.query)(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- response_message = cl.Message(content="")
57
 
58
- for token in response.response_gen:
59
- await response_message.stream_token(token=token)
60
 
61
- if response.response_txt:
62
- response_message.content = response.response_txt
 
 
63
 
64
- await response_message.send()
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import os
from typing import Any, List, Tuple

# Third-party
import chainlit as cl
import chromadb
import openai
import pandas as pd
from pydantic import BaseModel, Field
from sqlalchemy import create_engine

# LlamaIndex
from llama_index import ServiceContext, SQLDatabase, VectorStoreIndex
from llama_index.agent import OpenAIAgent
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.llms import OpenAI
from llama_index.node_parser.simple import SimpleNodeParser
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.retrievers import VectorIndexRetriever
from llama_index.storage.storage_context import StorageContext
from llama_index.tools import FunctionTool
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.types import (
    ExactMatchFilter,
    MetadataFilters,
    MetadataInfo,
    VectorStoreInfo,
)
# ---------------------------------------------------------------------------
# Module-level setup: LLM, embedding model, node parser, and the Chroma-backed
# vector index that the retrieval tool queries at chat time.
# ---------------------------------------------------------------------------

openai.api_key = os.environ["OPENAI_API_KEY"]

chunk_size = 1000

# Streaming gpt-3.5-turbo at temperature 0, plus the default OpenAI embedder.
embed_model = OpenAIEmbedding()
llm = OpenAI(temperature=0, model="gpt-3.5-turbo", streaming=True)

service_context = ServiceContext.from_defaults(
    llm=llm, chunk_size=chunk_size, embed_model=embed_model
)

# Token-based splitter feeding the node parser; same chunk size as the context.
text_splitter = TokenTextSplitter(chunk_size=chunk_size)
node_parser = SimpleNodeParser(text_splitter=text_splitter)

# In-memory Chroma collection backing an initially-empty vector index; nodes
# are inserted per-session in init().
chroma_client = chromadb.Client()
chroma_collection = chroma_client.create_collection("wikipedia_barbie_opp")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
wiki_vector_index = VectorStoreIndex(
    [], storage_context=storage_context, service_context=service_context
)

# Wikipedia pages for the two films, fetched once at import time.
movie_list = ["Barbie (film)", "Oppenheimer (film)"]
wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)
68
class AutoRetrieveModel(BaseModel):
    """Argument schema for the auto-retrieve tool (OpenAI function calling).

    The agent fills these fields; filter keys and values are paired
    positionally by auto_retrieve_fn.
    """

    query: str = Field(..., description="natural language query string")
    filter_key_list: List[str] = Field(
        ..., description="List of metadata filter field names"
    )
    filter_value_list: List[str] = Field(
        ...,
        description=(
            "List of metadata filter field values (corresponding to names specified in filter_key_list)"
        )
    )
79
+
80
def auto_retrieve_fn(
    query: str,
    filter_key_list: List[str],
    filter_value_list: List[str],
    top_k: int = 3,
):
    """Auto retrieval function.

    Performs auto-retrieval from a vector database, and then applies a set of
    filters.

    Args:
        query: Natural-language query string; falls back to "Query" if empty.
        filter_key_list: Metadata field names to filter on.
        filter_value_list: Values paired positionally with filter_key_list.
        top_k: Number of top-similarity nodes to retrieve (default 3).

    Returns:
        The query engine's response rendered as a string.
    """
    query = query or "Query"

    # zip() silently drops unpaired keys/values, matching original behavior.
    exact_match_filters = [
        ExactMatchFilter(key=k, value=v)
        for k, v in zip(filter_key_list, filter_value_list)
    ]
    # BUG FIX: the original referenced a module-global `top_k` that was only
    # ever assigned as a local inside init(), raising NameError at call time;
    # it is now a defaulted parameter. Also, VectorIndexRetriever's keyword is
    # `similarity_top_k` — it has no `top_k` argument.
    retriever = VectorIndexRetriever(
        wiki_vector_index,
        filters=MetadataFilters(filters=exact_match_filters),
        similarity_top_k=top_k,
    )
    query_engine = RetrieverQueryEngine.from_args(retriever)

    response = query_engine.query(query)
    return str(response)
99
+
100
+
101
+
102
@cl.author_rename
def rename(orig_author: str):
    """Map internal author labels to friendlier display names in the UI."""
    display_names = {"RetrievalQA": "Consulting The Llamaindex Tools"}
    if orig_author in display_names:
        return display_names[orig_author]
    # Unknown authors pass through unchanged.
    return orig_author
106
 
107
@cl.on_chat_start
async def init():
    """Build the per-session agent on chat start.

    Indexes the Wikipedia docs with per-movie metadata, wires up the
    auto-retrieve and SQL tools, and stores the combined agent in the
    user session under "barbenheimer_agent".
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # Tag every node with its source movie title so metadata filters work.
    for movie, wiki_doc in zip(movie_list, wiki_docs):
        nodes = node_parser.get_nodes_from_documents([wiki_doc])
        for node in nodes:
            node.metadata = {'title' : movie}
        wiki_vector_index.insert_nodes(nodes)

    # Schema description handed to the agent so it can fill in filters.
    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[MetadataInfo(
            name="title",
            type="str",
            description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
        )]
    )

    description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""

    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="auto_retrieve_tool",
        description=description,
        fn_schema=AutoRetrieveModel,
    )

    # Load the review CSVs into an in-memory SQLite database.
    barbie_df = pd.read_csv('./data/barbie.csv')
    oppenheimer_df = pd.read_csv('./data/oppenheimer.csv')

    engine = create_engine("sqlite+pysqlite:///:memory:")

    barbie_df.to_sql("barbie", engine)
    oppenheimer_df.to_sql("oppenheimer", engine)

    # BUG FIX: `sql_database` was referenced below but never defined,
    # raising NameError on every chat start.
    sql_database = SQLDatabase(engine, include_tables=['barbie', 'oppenheimer'])

    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=['barbie', 'oppenheimer']
    )

    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name='sql_tool',
        description=(
            "Useful for translating a natural language query into a SQL query over a table containing: " +
            "barbie, containing information related to reviews of the Barbie movie" +
            "oppenheimer, containing information related to reviews of the Oppenheimer movie"
        ),
    )

    # NOTE: the original also built two single-tool OpenAIAgents here and
    # never used them (the second silently shadowed the first); only the
    # combined agent below is kept.
    barbenheimer_agent = OpenAIAgent.from_tools(
        [auto_retrieve_tool, sql_tool], llm=llm, verbose=True
    )

    msg.content = "Index built!"
    await msg.send()

    cl.user_session.set("barbenheimer_agent", barbenheimer_agent)
187
 
 
 
188
 
189
@cl.on_message
async def main(message):
    """Relay the user's message to the session agent and send back its reply.

    Args:
        message: The incoming chat message from Chainlit.
    """
    barbenheimer_agent = cl.user_session.get("barbenheimer_agent")

    # FIX: agent.chat() is a blocking synchronous call; running it directly in
    # this coroutine stalls the event loop for the whole query. Run it in a
    # worker thread via cl.make_async (as the pre-rewrite version of this file
    # did for query_engine.query).
    response = await cl.make_async(barbenheimer_agent.chat)(message)

    await cl.Message(content=str(response)).send()