File size: 3,427 Bytes
efb3d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3769bb
efb3d62
 
 
 
 
 
 
 
 
 
 
 
c3769bb
efb3d62
 
 
 
 
 
 
c3769bb
 
 
 
 
 
 
 
efb3d62
 
 
 
 
c3769bb
 
efb3d62
 
 
 
 
 
 
c3769bb
efb3d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.storage import LocalFileStore
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# Module-level splitter shared by every chat session: chunks of up to 1000
# units (characters with the default length function) overlapping by 100.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# System prompt for the QA chain. "{context}" is the placeholder the
# retrieved chunks are substituted into; the instructions tell the model to
# answer only from that context and otherwise give a fixed refusal sentence.
system_template = """
Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."  
----------------
{context}"""

# Chat prompt: a system message carrying the retrieved context, followed by
# a human message containing the user's question ("{question}").
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate(messages=messages)
# Keyword arguments in the shape RetrievalQA.from_chain_type expects for
# its `chain_type_kwargs` parameter (supplies the custom prompt above).
chain_type_kwargs = {"prompt": prompt}

@cl.author_rename
def rename(orig_author: str):
    """Return the display name Chainlit should show for a message author.

    Registered via ``@cl.author_rename``. The LangChain chain name
    ``"RetrievalQA"`` is shown as ``"Consulting PageTurn"``; any other
    author name is returned unchanged.
    """
    display_names = {"RetrievalQA": "Consulting PageTurn"}
    if orig_author in display_names:
        return display_names[orig_author]
    return orig_author

@cl.on_chat_start
async def init():
    """Build the retrieval index for a new chat and store the QA chain.

    Reads ``./data/aerodynamic_drag.txt``, splits it with the module-level
    ``text_splitter``, embeds the chunks through an on-disk embedding cache
    (``./cache/``), builds a FAISS index, and stores a ``RetrievalQA`` chain
    in the user session under the key ``"chain"``. A status message is
    shown while building and updated once the chain is ready.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    source_path = "./data/aerodynamic_drag.txt"
    with open(source_path, "r", encoding="utf-8") as f:
        aerodynamic_drag_data = f.read()

    # transform_documents() expects Document objects, not raw strings;
    # create_documents() wraps the text in Documents and attaches metadata.
    # The "source" key is what the message handler reads to cite sources.
    documents = text_splitter.create_documents(
        [aerodynamic_drag_data], metadatas=[{"source": source_path}]
    )

    # Cache embeddings on disk, namespaced by the embedding model name, so
    # chunks embedded in an earlier run can be reused from ./cache/.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # FAISS.from_documents is a blocking call; make_async runs it without
    # blocking the event loop.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        chain_type_kwargs=chain_type_kwargs,
    )

    # Store the chain before announcing readiness: the await below yields to
    # the event loop, and a message handled then must find the chain.
    cl.user_session.set("chain", chain)

    # Edit the existing status message rather than posting a second one.
    msg.content = "Index built!"
    await msg.update()

@cl.on_message
async def main(message):
    """Answer a chat message with the session's RetrievalQA chain.

    The answer is followed by a de-duplicated, first-seen-order list of the
    retrieved documents' ``source`` metadata (or "No sources found"), and one
    text element per distinct source is attached to the reply.

    Args:
        message: The incoming message. Accepted either as the raw question
            ``str`` or as a message object exposing the text as ``.content``.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # The chat-start handler has not stored a chain for this session yet
        # (still building, or it failed).
        await cl.Message(
            content="The index is not ready yet. Please try again shortly."
        ).send()
        return

    # Support both a plain-string message and a message object with .content.
    query = getattr(message, "content", message)

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): with stream_final_answer=False this flag likely has no
    # effect; kept as in the original — confirm against Chainlit docs.
    cb.answer_reached = True
    res = await chain.acall(query, callbacks=[cb])

    answer = res["result"]

    # Source documents come from the chain result (return_source_documents).
    # Documents lacking "source" metadata are skipped instead of raising
    # KeyError; dict.fromkeys de-duplicates while keeping first-seen order.
    all_sources = (
        doc.metadata.get("source") for doc in res.get("source_documents", [])
    )
    sources = [str(s) for s in dict.fromkeys(all_sources) if s is not None]

    # Name each element after its source so the names listed in the answer
    # text refer to the attached elements.
    source_elements = [cl.Text(content=source, name=source) for source in sources]

    if sources:
        answer += f"\nSources: {', '.join(sources)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()