Spaces:
Sleeping
Sleeping
import chainlit as cl | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import FAISS | |
from langchain.chains import RetrievalQA | |
from langchain.chat_models import ChatOpenAI | |
from langchain.storage import LocalFileStore | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
SystemMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
# Splitter used to cut the source text into chunks of at most 1000 characters,
# with 100 characters of overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# System prompt for the QA chain. "{context}" is the placeholder that receives
# the retrieved text; the model is told to refuse when the context does not
# contain the answer.
system_template = """
Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
----------------
{context}"""

# Chat prompt = system instructions followed by the user's question.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]

# from_messages derives the prompt's input variables ("context", "question")
# from the message templates instead of relying on constructor validation.
prompt = ChatPromptTemplate.from_messages(messages)

# Extra keyword arguments for the document-combining chain.
chain_type_kwargs = {"prompt": prompt}
# NOTE(review): presumably meant to be registered as Chainlit's author-rename
# hook; no decorator is visible in this view of the file -- confirm.
def rename(orig_author: str) -> str:
    """Return the display name to use for a message author.

    Args:
        orig_author: The author name as produced by the chain.

    Returns:
        "Consulting PageTurn" when ``orig_author`` is "RetrievalQA";
        otherwise ``orig_author`` unchanged.
    """
    rename_dict = {"RetrievalQA": "Consulting PageTurn"}
    return rename_dict.get(orig_author, orig_author)
# NOTE(review): presumably meant to run at chat start (e.g. via a Chainlit
# ``on_chat_start`` decorator); none is visible in this view -- confirm.
async def init():
    """Build the retrieval QA chain and store it in the user session.

    Steps:
      1. Tell the user that indexing has started.
      2. Read ``./data/aerodynamic_drag.txt`` and split it into chunks.
      3. Embed the chunks with OpenAI embeddings, cached on disk under
         ``./cache/``, and index them in a FAISS vector store.
      4. Create a "stuff" ``RetrievalQA`` chain over that index using the
         module-level ``prompt``.
      5. Update the status message and save the chain under the session
         key ``"chain"``.

    Raises:
        OSError: If the data file cannot be opened or read.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # Read text from a .txt file. The encoding is explicit so the result does
    # not depend on the platform's default locale.
    with open("./data/aerodynamic_drag.txt", "r", encoding="utf-8") as f:
        aerodynamic_drag_data = f.read()

    # Split the text into smaller chunks. create_documents accepts raw strings
    # and returns Document objects; transform_documents expects Document
    # objects as input and fails on a plain str.
    documents = text_splitter.create_documents([aerodynamic_drag_data])

    # Create a local file store for caching embeddings, namespaced by the
    # embedding model name so different models do not share cache entries.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # FAISS.from_documents is synchronous; wrap it with cl.make_async so it
    # can be awaited.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
    )

    # Edit the message that was already sent rather than sending it again.
    msg.content = "Index built!"
    await msg.update()

    cl.user_session.set("chain", chain)
# NOTE(review): presumably meant to be the Chainlit message handler (e.g. via
# an ``on_message`` decorator); none is visible in this view -- confirm.
async def main(message):
    """Answer a user message with the session's QA chain and list its sources.

    Args:
        message: The incoming user message. Either the question text itself
            or an object exposing the text as ``.content``.

    The reply is the chain's ``"result"`` followed by either a
    ``Sources: ...`` line listing each distinct source once, or
    ``No sources found`` when no retrieved document carries a ``"source"``
    metadata entry. One text element is attached per listed source.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True

    # Accept both a plain string and a message object carrying the text in
    # ``.content``; a str has no such attribute, so it is passed through.
    query = getattr(message, "content", message)
    res = await chain.acall(query, callbacks=[cb])

    answer = res["result"]

    # Collect each distinct source once, in retrieval order. Documents without
    # a "source" metadata entry are skipped instead of raising KeyError.
    found_sources = (doc.metadata.get("source") for doc in res["source_documents"])
    unique_sources = list(dict.fromkeys(src for src in found_sources if src))

    # NOTE(review): the IMDB prefix and "Review URL" label look like leftovers
    # from a different data set -- confirm they are intended for this corpus.
    source_urls = ["https://www.imdb.com" + src for src in unique_sources]

    # Create the text elements referenced in the message.
    source_elements = [cl.Text(content=url, name="Review URL") for url in source_urls]

    # Join the URL strings directly rather than reading them back from the
    # elements, so the text does not depend on whether an element stores its
    # content as str or bytes.
    if source_urls:
        answer += f"\nSources: {', '.join(source_urls)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()