mencraft committed
Commit 9231bf7 · 1 Parent(s): e3f5fd9

Update app.py

Files changed (1)
  1. app.py +63 -93
app.py CHANGED
@@ -1,110 +1,80 @@
- import chainlit as cl
- from langchain.embeddings.openai import OpenAIEmbeddings
- from langchain.document_loaders.csv_loader import CSVLoader
- from langchain.embeddings import CacheBackedEmbeddings
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.vectorstores import FAISS
- from langchain.chains import RetrievalQA
- from langchain.chat_models import ChatOpenAI
- from langchain.storage import LocalFileStore
- from langchain.prompts.chat import (
-     ChatPromptTemplate,
-     SystemMessagePromptTemplate,
-     HumanMessagePromptTemplate,
  )
  import chainlit as cl

- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

- system_template = """
- Use the following pieces of context to answer the user's question.
- Please respond as if you were Ken from the movie Barbie. Ken is a well-meaning but naive character who loves to Beach. He talks like a typical Californian Beach Bro, but he doesn't use the word "Dude" so much.
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
- You can make inferences based on the context as long as it still faithfully represents the feedback.
-
- Example of your response should be:
-
- ```
- The answer is foo
- ```
-
- Begin!
- ----------------
- {context}"""
-
- messages = [
-     SystemMessagePromptTemplate.from_template(system_template),
-     HumanMessagePromptTemplate.from_template("{question}"),
- ]
- prompt = ChatPromptTemplate(messages=messages)
- chain_type_kwargs = {"prompt": prompt}
-
- @cl.author_rename
- def rename(orig_author: str):
-     rename_dict = {"RetrievalQA": "Consulting The Kens"}
-     return rename_dict.get(orig_author, orig_author)

  @cl.on_chat_start
- async def init():
-     msg = cl.Message(content=f"Building Index...")
-     await msg.send()
-
-     # build FAISS index from csv
-     loader = CSVLoader(file_path="./data/barbie.csv", source_column="Review_Url")
-     data = loader.load()
-     documents = text_splitter.transform_documents(data)
-     store = LocalFileStore("./cache/")
-     core_embeddings_model = OpenAIEmbeddings()
-     embedder = CacheBackedEmbeddings.from_bytes_store(
-         core_embeddings_model, store, namespace=core_embeddings_model.model
-     )
-     # make async docsearch
-     docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
-
-     chain = RetrievalQA.from_chain_type(
-         ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
-         chain_type="stuff",
-         return_source_documents=True,
-         retriever=docsearch.as_retriever(),
-         chain_type_kwargs={"prompt": prompt},
      )

-     msg.content = f"Index built!"
-     await msg.send()

-     cl.user_session.set("chain", chain)


  @cl.on_message
  async def main(message):
-     chain = cl.user_session.get("chain")
-     cb = cl.AsyncLangchainCallbackHandler(
-         stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
-     )
-     cb.answer_reached = True
-     res = await chain.acall(message, callbacks=[cb])
-
-     answer = res["result"]
-     source_elements = []
-     visited_sources = set()
-
-     # Get the documents from the user session
-     docs = res["source_documents"]
-     metadatas = [doc.metadata for doc in docs]
-     all_sources = [m["source"] for m in metadatas]
-
-     for source in all_sources:
-         if source in visited_sources:
-             continue
-         visited_sources.add(source)
-         # Create the text element referenced in the message
-         source_elements.append(
-             cl.Text(content="https://www.imdb.com" + source, name="Review URL")
-         )
-
-     if source_elements:
-         answer += f"\nSources: {', '.join([e.content.decode('utf-8') for e in source_elements])}"
-     else:
-         answer += "\nNo sources found"
-
-     await cl.Message(content=answer, elements=source_elements).send()
 
+ import os
+ import openai
+
+ from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
+ from llama_index.callbacks.base import CallbackManager
+ from llama_index import (
+     LLMPredictor,
+     ServiceContext,
+     SimpleDirectoryReader,
+     StorageContext,
+     load_index_from_storage,
  )
+ from langchain.chat_models import ChatOpenAI
+ from llama_index.llms import OpenAI
+ from llama_index import VectorStoreIndex
  import chainlit as cl


+ openai.api_key = os.environ.get("OPENAI_API_KEY")

+ # try:
+ #     # rebuild storage context
+ #     storage_context = StorageContext.from_defaults(persist_dir="./storage")
+ #     # load index
+ #     index = load_index_from_storage(storage_context)
+ # except:
+ #     from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader

+ #     documents = SimpleDirectoryReader("./data").load_data()
+ #     index = GPTVectorStoreIndex.from_documents(documents)
+ #     index.storage_context.persist()
+ documents = SimpleDirectoryReader(
+     input_files=["hitchhikers.pdf"]
+ ).load_data()

+ index = VectorStoreIndex.from_documents(documents)


  @cl.on_chat_start
+ async def factory():
+     # llm_predictor = LLMPredictor(
+     #     llm=ChatOpenAI(
+     #         temperature=0,
+     #         model_name="gpt-3.5-turbo",
+     #         streaming=True,
+     #     ),
+     # )
+     # service_context = ServiceContext.from_defaults(
+     #     llm_predictor=llm_predictor,
+     #     chunk_size=512,
+     #     callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
+     # )
+
+     gpt_35_context = ServiceContext.from_defaults(
+         llm=OpenAI(model="gpt-3.5-turbo", temperature=0.3),
+         context_window=2048,  # limit the context window artificially to test the refine process
+         callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
      )

+     query_engine = index.as_query_engine(
+         service_context=gpt_35_context
+     )

+     cl.user_session.set("query_engine", query_engine)


  @cl.on_message
  async def main(message):
+     query_engine = cl.user_session.get("query_engine")  # type: RetrieverQueryEngine
+     response = await cl.make_async(query_engine.query)(message)
+     print(response)
+     response_message = cl.Message(content="")
+
+     # for token in response.response_gen:
+     #     await response_message.stream_token(token=token)
+
+     # if response.response_txt:
+     response_message.content = str(response)  # query() returns a Response object; Message.content expects a string
+
+     await response_message.send()
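
Note: the new handler leaves token streaming commented out and sends the whole answer in one message. A minimal sketch of re-enabling the streaming path, assuming a llama_index version where as_query_engine(streaming=True) makes query() return a streaming response exposing a response_gen token generator (as the commented-out lines above suggest), and Chainlit's Message.stream_token coroutine:

# Sketch only -- untested against the exact versions pinned for this app.
import chainlit as cl
from llama_index import SimpleDirectoryReader, VectorStoreIndex

documents = SimpleDirectoryReader(input_files=["hitchhikers.pdf"]).load_data()
index = VectorStoreIndex.from_documents(documents)

@cl.on_chat_start
async def factory():
    # streaming=True makes query() return a streaming response instead of a plain Response
    query_engine = index.as_query_engine(streaming=True)
    cl.user_session.set("query_engine", query_engine)

@cl.on_message
async def main(message):
    query_engine = cl.user_session.get("query_engine")
    # Run the blocking query off the event loop, then forward tokens as they arrive.
    response = await cl.make_async(query_engine.query)(message)
    response_message = cl.Message(content="")
    for token in response.response_gen:
        await response_message.stream_token(token=token)
    await response_message.send()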
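
Similarly, the commented-out try/except block near the top of the new file hints at persisting the index rather than rebuilding it on every start. A minimal sketch of that flow, assuming ./storage is writable and using the StorageContext / load_index_from_storage APIs already imported in the commit:

# Sketch only: build the index once, then reload it from disk on later runs.
import os
from llama_index import (
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)

if os.path.isdir("./storage"):
    # Reload the previously persisted index.
    storage_context = StorageContext.from_defaults(persist_dir="./storage")
    index = load_index_from_storage(storage_context)
else:
    # First run: build the index from the PDF and persist it for next time.
    documents = SimpleDirectoryReader(input_files=["hitchhikers.pdf"]).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir="./storage")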