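"""Chainlit app for a "Barbenheimer" agent.

An OpenAIAgent answers questions about the Barbie and Oppenheimer films using
two LlamaIndex tools: an auto-retrieval tool over Wikipedia pages stored in
Chroma, and a natural-language-to-SQL tool over review CSVs loaded into SQLite.
"""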
import chainlit as cl
from llama_index import ServiceContext
from llama_index.node_parser.simple import SimpleNodeParser
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.llms import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index import VectorStoreIndex
from llama_index.vector_stores import ChromaVectorStore
from llama_index.storage.storage_context import StorageContext
import chromadb
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
    VectorStoreInfo,
    MetadataInfo,
    ExactMatchFilter,
    MetadataFilters,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
from llama_index.agent import OpenAIAgent
import pandas as pd
from sqlalchemy import create_engine
from llama_index import SQLDatabase
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.tools.query_engine import QueryEngineTool
import openai
import os

openai.api_key = os.environ["OPENAI_API_KEY"]

# Shared configuration for embeddings, chunking, and retrieval.
embed_model = OpenAIEmbedding()
chunk_size = 1000
top_k = 3  # number of nodes to retrieve; referenced by auto_retrieve_fn below
llm = OpenAI(
    temperature=0,
    model="gpt-3.5-turbo",
    streaming=True,
)
service_context = ServiceContext.from_defaults(
    llm=llm,
    chunk_size=chunk_size,
    embed_model=embed_model,
)
text_splitter = TokenTextSplitter(
    chunk_size=chunk_size
)
node_parser = SimpleNodeParser(
    text_splitter=text_splitter
)
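
# chromadb.Client() is the in-memory (ephemeral) client, so the collection
# lives only for the lifetime of the process and is populated at chat start.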
chroma_client = chromadb.Client()
chroma_collection = chroma_client.create_collection("wikipedia_barbie_opp")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
wiki_vector_index = VectorStoreIndex([], storage_context=storage_context, service_context=service_context)
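
# Pull the two Wikipedia pages up front. NOTE: WikipediaReader depends on the
# `wikipedia` package (`pip install wikipedia`); auto_suggest=False keeps the
# exact page titles.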
movie_list = ["Barbie (film)", "Oppenheimer (film)"]
wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)

class AutoRetrieveModel(BaseModel):
    query: str = Field(..., description="natural language query string")
    filter_key_list: List[str] = Field(
        ..., description="List of metadata filter field names"
    )
    filter_value_list: List[str] = Field(
        ...,
        description=(
            "List of metadata filter field values (corresponding to names"
            " specified in filter_key_list)"
        ),
    )

def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
):
    """Auto retrieval function.

    Performs auto-retrieval from a vector database, then applies a set of
    exact-match metadata filters.
    """
    query = query or "Query"

    exact_match_filters = [
        ExactMatchFilter(key=k, value=v)
        for k, v in zip(filter_key_list, filter_value_list)
    ]
    # VectorIndexRetriever expects `similarity_top_k`; `top_k` is the
    # module-level setting defined above.
    retriever = VectorIndexRetriever(
        wiki_vector_index,
        filters=MetadataFilters(filters=exact_match_filters),
        similarity_top_k=top_k,
    )
    query_engine = RetrieverQueryEngine.from_args(retriever)
    response = query_engine.query(query)
    return str(response)
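
# Illustrative example of the kind of call the agent is expected to make:
#   auto_retrieve_fn("What is the plot?", ["title"], ["Barbie (film)"])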

@cl.author_rename
def rename(orig_author: str):
    rename_dict = {"RetrievalQA": "Consulting The Llamaindex Tools"}
    return rename_dict.get(orig_author, orig_author)

@cl.on_chat_start
async def init():
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # Tag each node with its source movie so the metadata filters can match.
    for movie, wiki_doc in zip(movie_list, wiki_docs):
        nodes = node_parser.get_nodes_from_documents([wiki_doc])
        for node in nodes:
            node.metadata = {"title": movie}
        wiki_vector_index.insert_nodes(nodes)
    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[
            MetadataInfo(
                name="title",
                type="str",
                description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
            )
        ],
    )
    description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="auto_retrieve_tool",
        description=description,
        fn_schema=AutoRetrieveModel,
    )
    # Standalone retrieval-only agent; superseded by barbenheimer_agent below.
    agent = OpenAIAgent.from_tools(
        [auto_retrieve_tool], llm=llm, verbose=True
    )
    # Load the review CSVs into an in-memory SQLite database for text-to-SQL.
    barbie_df = pd.read_csv("./data/barbie.csv")
    oppenheimer_df = pd.read_csv("./data/oppenheimer.csv")
    engine = create_engine("sqlite+pysqlite:///:memory:")
    barbie_df.to_sql(
        "barbie",
        engine
    )
    oppenheimer_df.to_sql(
        "oppenheimer",
        engine
    )
    # Wrap the engine so LlamaIndex can inspect the table schemas.
    sql_database = SQLDatabase(engine)
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=["barbie", "oppenheimer"],
    )
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql_tool",
        description=(
            "Useful for translating a natural language query into a SQL query"
            " over two tables: 'barbie', containing information related to"
            " reviews of the Barbie movie, and 'oppenheimer', containing"
            " information related to reviews of the Oppenheimer movie."
        ),
    )
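    # Illustrative question for the SQL tool (exact columns depend on the CSVs):
    #   "How many reviews of Barbie are there?"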
    # Standalone SQL-only agent; superseded by barbenheimer_agent below.
    agent = OpenAIAgent.from_tools(
        [sql_tool], llm=llm, verbose=True
    )
    # The combined agent that can route between both tools.
    barbenheimer_agent = OpenAIAgent.from_tools(
        [auto_retrieve_tool, sql_tool], llm=llm, verbose=True
    )

    msg.content = "Index built!"
    await msg.send()

    cl.user_session.set("barbenheimer_agent", barbenheimer_agent)

@cl.on_message
async def main(message):
    barbenheimer_agent = cl.user_session.get("barbenheimer_agent")
    response = barbenheimer_agent.chat(message)
    await cl.Message(content=str(response)).send()
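
# To try the app locally (assuming the two CSVs exist under ./data and
# OPENAI_API_KEY is set):
#   chainlit run app.py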