"""Chainlit app: retrieval-augmented QA over a single local text file.

On chat start, the text file is split into chunks, embedded through a
file-store-backed embedding cache into a FAISS index, and wrapped in a
``RetrievalQA`` chain stored in the user session. Each incoming message is
answered by that chain, and the retrieved chunks are attached to the reply,
grouped by their ``source`` metadata.
"""

import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.storage import LocalFileStore
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# Text file that is indexed for every chat session.
DATA_PATH = "./data/aerodynamic_drag.txt"
# Directory backing the embedding cache (LocalFileStore), so identical chunks
# are not re-embedded when a new chat session rebuilds the index.
CACHE_DIR = "./cache/"

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

system_template = """
Use the following pieces of context to answer the user's question.
If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
----------------
{context}"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate(messages=messages)
# Passed to RetrievalQA.from_chain_type so the "stuff" chain uses our prompt.
chain_type_kwargs = {"prompt": prompt}


@cl.author_rename
def rename(orig_author: str):
    """Map internal chain names to user-facing author names in the UI."""
    rename_dict = {"RetrievalQA": "Consulting PageTurn"}
    return rename_dict.get(orig_author, orig_author)


@cl.on_chat_start
async def init():
    """Build the retrieval index and QA chain for this chat session.

    Shows a status message while indexing, then stores the chain in the
    user session under ``"chain"`` for :func:`main` to use.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # Read the source text; explicit encoding avoids platform-dependent decoding.
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        aerodynamic_drag_data = f.read()

    # transform_documents() expects Document objects, not raw strings;
    # create_documents() wraps the text and attaches metadata. The "source"
    # key is what main() reads to report where an answer came from.
    documents = text_splitter.create_documents(
        [aerodynamic_drag_data], metadatas=[{"source": DATA_PATH}]
    )

    # Cache embeddings on disk, namespaced by model so different embedding
    # models never share cached vectors.
    store = LocalFileStore(CACHE_DIR)
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # FAISS.from_documents is a synchronous call; cl.make_async wraps it so
    # it can be awaited without blocking the event loop.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        chain_type_kwargs=chain_type_kwargs,
    )

    # Edit the existing status message instead of sending a second one.
    msg.content = "Index built!"
    await msg.update()

    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message):
    """Answer a user message with the session's QA chain and cite sources.

    Args:
        message: The incoming user message. Newer Chainlit versions pass a
            ``cl.Message`` (text in ``.content``); older ones pass the text.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True

    # The chain needs the query text, not a Chainlit message object.
    query = getattr(message, "content", message)
    res = await chain.acall(query, callbacks=[cb])
    answer = res["result"]

    # Group retrieved chunk texts by their "source" metadata, preserving the
    # order in which each source first appears; chunks lacking a source are
    # skipped rather than raising KeyError.
    chunks_by_source = {}
    for doc in res["source_documents"]:
        source = doc.metadata.get("source")
        if source is None:
            continue
        chunks_by_source.setdefault(source, []).append(doc.page_content)

    # One text element per source, named after the source so the names listed
    # in the answer below refer to these elements.
    source_elements = [
        cl.Text(content="\n\n".join(chunks), name=source)
        for source, chunks in chunks_by_source.items()
    ]

    if source_elements:
        answer += f"\nSources: {', '.join(chunks_by_source)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()