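"""Chainlit app for retrieval-augmented QA over Arxiv papers.

Documents are pulled with ArxivLoader, embedded and indexed into Pinecone via
PineconeIndexer, and user questions are answered by a LangChain LCEL chain
backed by gpt-3.5-turbo, grounded in the retrieved context.
"""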
from operator import itemgetter
import chainlit as cl
from langchain.schema.runnable import RunnablePassthrough
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

from utils import ArxivLoader, PineconeIndexer

system_template = """
Use the provided context to answer the user's query.

You may not answer the user's query unless there is specific context in the following text.

If you do not know the answer, or cannot answer, please respond with "I don't know".

Context:
{context}
"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]

prompt = ChatPromptTemplate.from_messages(messages)

@cl.author_rename
def rename(orig_author: str):
    # map Chainlit's internal step-author names to a friendlier display label
    rename_dict = {"RetrievalQA": "Learning about Nuclear Fission"}
    return rename_dict.get(orig_author, orig_author)

@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # load documents from Arxiv
    axloader = ArxivLoader()
    axloader.main()

    # build index in Pinecone
    pi = PineconeIndexer()
    pi.load_embedder()
    pi.index_documents(axloader.documents)
    retriever = pi.get_vectorstore().as_retriever()
    print(pi.index.describe_index_stats())  # log index stats so a failed build is visible

    # build llm
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0
    )

    msg.content = "Index built!"
    await msg.update()  # update the message in place rather than sending a duplicate

    cl.user_session.set("llm", llm)
    cl.user_session.set("retriever", retriever)

@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):

    llm = cl.user_session.get("llm")
    retriever = cl.user_session.get("retriever")

    # First map: retrieve context for the question and pass the question through.
    # The RunnablePassthrough.assign step re-exposes the retrieved context so the
    # final map can return it alongside the model's response.
    retrieval_augmented_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {
            "response": prompt | llm,
            "context": itemgetter("context"),
        }
    )

    answer = retrieval_augmented_qa_chain.invoke({"question": message.content})

    # answer["response"] is the chat model's message; answer["context"] holds the retrieved docs
    await cl.Message(content=answer["response"].content).send()
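
# Typically launched with the Chainlit CLI (substitute whatever this module is
# saved as for app.py), e.g.:
#   chainlit run app.py -w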