# Spaces: Runtime error
# (Page-header status text captured together with this source; it is not part of the program.)
import chainlit as cl | |
from llama_index import ServiceContext | |
from llama_index.node_parser.simple import SimpleNodeParser | |
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter | |
from llama_index.llms import OpenAI | |
from llama_index.embeddings.openai import OpenAIEmbedding | |
from llama_index import VectorStoreIndex | |
from llama_index.vector_stores import ChromaVectorStore | |
from llama_index.storage.storage_context import StorageContext | |
import chromadb | |
from llama_index.readers.wikipedia import WikipediaReader | |
from llama_index import SQLDatabase | |
from llama_index.tools import FunctionTool | |
from llama_index.vector_stores.types import ( | |
VectorStoreInfo, | |
MetadataInfo, | |
ExactMatchFilter, | |
MetadataFilters, | |
) | |
from llama_index.retrievers import VectorIndexRetriever | |
from llama_index.query_engine import RetrieverQueryEngine | |
from typing import List, Tuple, Any | |
from pydantic import BaseModel, Field | |
from llama_index.agent import OpenAIAgent | |
import pandas as pd | |
from sqlalchemy import create_engine | |
from llama_index import SQLDatabase | |
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine | |
from llama_index.tools.query_engine import QueryEngineTool | |
# OpenAI credentials are not configured in code: the explicit
# `openai.api_key = os.environ["OPENAI_API_KEY"]` assignment was left disabled.
# NOTE(review): presumably the key is supplied through the OPENAI_API_KEY
# environment variable at deploy time -- confirm.
embed_model = OpenAIEmbedding()

# Chunk size shared by the service context and the node parser's splitter.
chunk_size = 1000

# Chat model (temperature 0) shared by the query engines and the agents.
llm = OpenAI(
    temperature=0,
    model="gpt-3.5-turbo",
    streaming=True,
)
service_context = ServiceContext.from_defaults(
    llm=llm,
    chunk_size=chunk_size,
    embed_model=embed_model,
)
text_splitter = TokenTextSplitter(chunk_size=chunk_size)
node_parser = SimpleNodeParser(text_splitter=text_splitter)

# In-memory Chroma collection backing the Wikipedia vector index.
# get_or_create_collection (not create_collection, which raises when the name
# is already taken) so that executing this module again against the same
# in-memory client -- e.g. on a reload -- reuses the collection instead of failing.
chroma_client = chromadb.Client()
chroma_collection = chroma_client.get_or_create_collection("wikipedia_barbie_opp")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# The index starts empty; node insertion happens later, once chunks are tagged.
wiki_vector_index = VectorStoreIndex(
    [], storage_context=storage_context, service_context=service_context
)

# Wikipedia page titles to load; the same strings are used as "title" metadata.
movie_list = ["Barbie (film)", "Oppenheimer (film)"]
# Network call at import time: fetches both pages via WikipediaReader.
wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)

# Intended number of chunks retrieved per auto-retrieval query.
top_k = 3
class AutoRetrieveModel(BaseModel):
    # Argument schema for the auto-retrieval tool (used as its fn_schema).
    # Deliberately no class docstring: pydantic copies a model docstring into
    # the generated JSON schema, which would change what the LLM is shown.
    # All three fields are required (no default given to Field).

    query: str = Field(description="natural language query string")

    # Parallel lists: filter_key_list[i] pairs with filter_value_list[i].
    filter_key_list: List[str] = Field(
        description="List of metadata filter field names",
    )
    filter_value_list: List[str] = Field(
        description="List of metadata filter field values (corresponding to names specified in filter_key_list)",
    )
def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
):
    """Auto retrieval function.

    Performs auto-retrieval from a vector database, and then applies a set of filters.

    Runs ``query`` against ``wiki_vector_index``, restricted to nodes whose
    metadata exactly matches every (key, value) pair built from the two lists,
    and returns the synthesized answer as text.

    Args:
        query: Natural-language question. A falsy value (``""``/``None``) is
            replaced with the placeholder ``"Query"``.
        filter_key_list: Metadata field names to filter on.
        filter_value_list: Values for the corresponding names in
            ``filter_key_list``. Pairs are formed with ``zip``, so surplus
            entries in the longer list are ignored.

    Returns:
        str: The query engine's response rendered as a string.
    """
    query = query or "Query"
    exact_match_filters = [
        ExactMatchFilter(key=key, value=value)
        for key, value in zip(filter_key_list, filter_value_list)
    ]
    retriever = VectorIndexRetriever(
        wiki_vector_index,
        filters=MetadataFilters(filters=exact_match_filters),
        # The retriever's result-count parameter is `similarity_top_k`; a
        # `top_k=` keyword is not one of its parameters and would be ignored.
        similarity_top_k=top_k,
    )
    # Use the module's service_context so synthesis runs on the configured
    # `llm` instead of the library's default model.
    query_engine = RetrieverQueryEngine.from_args(
        retriever, service_context=service_context
    )
    response = query_engine.query(query)
    return str(response)
def rename(orig_author: str):
    """Return the display name for ``orig_author``.

    Only "RetrievalQA" has a replacement; any other name is returned as-is.
    NOTE(review): looks intended as a Chainlit author-rename callback, but
    nothing in this file registers it -- confirm.
    """
    display_names = {"RetrievalQA": "Consulting The Llamaindex Tools"}
    try:
        return display_names[orig_author]
    except KeyError:
        return orig_author
def _index_wiki_docs():
    """Chunk the Wikipedia pages and add them to the shared vector index.

    Each node's metadata is replaced by ``{"title": <page title>}`` so the
    auto-retrieval tool can filter on ``title``. ``wiki_vector_index`` and its
    Chroma collection are module-level and shared by every chat session, so
    insertion only happens while the collection is still empty; otherwise each
    new session would add another copy of every chunk.
    """
    if chroma_collection.count() > 0:
        return
    # NOTE(review): assumes load_data returns documents in the same order as
    # movie_list -- confirm.
    for movie, wiki_doc in zip(movie_list, wiki_docs):
        nodes = node_parser.get_nodes_from_documents([wiki_doc])
        for node in nodes:
            node.metadata = {"title": movie}
        wiki_vector_index.insert_nodes(nodes)


def _make_auto_retrieve_tool():
    """Wrap ``auto_retrieve_fn`` in a FunctionTool whose description embeds the
    vector-store metadata schema the agent may filter on."""
    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[
            MetadataInfo(
                name="title",
                type="str",
                # Built from movie_list so the advertised values always match
                # the "title" metadata written by _index_wiki_docs.
                description=f"title of the movie, one of [{', '.join(movie_list)}]",
            )
        ],
    )
    description = (
        "Use this tool to look up semantic information about films.\n"
        "The vector database schema is given below:\n"
        f"{vector_store_info.json()}\n"
    )
    return FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="auto_retrieve_tool",
        description=description,
        fn_schema=AutoRetrieveModel,
    )


def _make_sql_tool():
    """Load the review CSVs into a fresh in-memory SQLite database and expose it
    as a text-to-SQL QueryEngineTool."""
    review_tables = {
        "barbie": "./data/barbie.csv",
        "oppenheimer": "./data/oppenheimer.csv",
    }
    engine = create_engine("sqlite+pysqlite:///:memory:")
    for table_name, csv_path in review_tables.items():
        pd.read_csv(csv_path).to_sql(table_name, engine)

    sql_database = SQLDatabase(engine, include_tables=list(review_tables))
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=list(review_tables),
        # Generate SQL with the configured llm rather than the library default.
        service_context=service_context,
    )
    return QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql_tool",
        description=(
            "Useful for translating a natural language query into a SQL query "
            "over the following tables: "
            "barbie, containing information related to reviews of the Barbie movie; "
            "and oppenheimer, containing information related to reviews of the "
            "Oppenheimer movie"
        ),
    )


@cl.on_chat_start
async def init():
    """Build the Barbenheimer agent for a new chat session.

    Posts a status message, ensures the Wikipedia chunks are in the shared
    vector index, builds the vector auto-retrieval tool and the text-to-SQL
    tool, and stores an OpenAIAgent using both in the user session under
    "barbenheimer_agent".
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    _index_wiki_docs()
    barbenheimer_agent = OpenAIAgent.from_tools(
        [_make_auto_retrieve_tool(), _make_sql_tool()], llm=llm, verbose=True
    )
    # Store the agent before announcing readiness, so a message handled while
    # the update below is awaited already finds it.
    cl.user_session.set("barbenheimer_agent", barbenheimer_agent)

    msg.content = "Index built!"
    await msg.update()
@cl.on_message
async def main(message):
    """Answer one user message with this session's Barbenheimer agent.

    Args:
        message: The user's input. If it has a ``content`` attribute (a
            ``cl.Message``), that text is used; otherwise the value itself is
            passed to the agent. NOTE(review): older Chainlit releases passed
            a plain ``str`` and newer ones pass a ``cl.Message`` -- confirm the
            deployed version.
    """
    barbenheimer_agent = cl.user_session.get("barbenheimer_agent")
    if barbenheimer_agent is None:
        # The session is not initialized yet; reply instead of crashing on None.
        await cl.Message(
            content="The agent is still being set up; please try again in a moment."
        ).send()
        return
    text = getattr(message, "content", message)
    # NOTE(review): agent.chat is synchronous, so it blocks the event loop for
    # the whole LLM/tool round-trip.
    response = barbenheimer_agent.chat(text)
    await cl.Message(content=str(response)).send()