File size: 5,631 Bytes
2f6ea1f
 
 
7cc3907
8c18b31
 
 
7cc3907
cdda8d7
 
 
f7c72fc
 
cdda8d7
6c995cf
649e581
cdda8d7
 
f7c72fc
cdda8d7
 
 
 
 
 
dd49b84
2f6ea1f
7cc3907
cdda8d7
 
 
 
 
8c18b31
31cbd5c
8c18b31
014336b
 
 
 
cdda8d7
014336b
2f6ea1f
70c2a60
f7c72fc
70c2a60
31cbd5c
cdda8d7
6c995cf
70c2a60
f7c72fc
 
 
 
cdda8d7
31cbd5c
cdda8d7
 
 
 
 
31cbd5c
cdda8d7
 
 
 
 
70c2a60
cdda8d7
 
 
 
 
 
 
f7c72fc
 
cdda8d7
 
f7c72fc
 
 
cdda8d7
 
 
 
 
 
 
6c995cf
cdda8d7
 
dd49b84
 
 
cdda8d7
 
 
dd49b84
 
cdda8d7
 
 
 
31cbd5c
f7c72fc
8c18b31
7cc3907
014336b
dd49b84
cdda8d7
dd49b84
 
cdda8d7
dd49b84
 
 
cdda8d7
 
 
 
dd49b84
cdda8d7
649e581
01e3f20
cdda8d7
 
 
 
 
 
 
 
 
 
 
 
dd49b84
cdda8d7
dd49b84
 
 
 
cdda8d7
56875e8
cdda8d7
dd49b84
01e3f20
 
dd49b84
cdda8d7
dd49b84
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# You can find this code for Chainlit python streaming here (https://docs.chainlit.io/concepts/streaming/python)

# OpenAI Chat completion
import os
import chainlit as cl  # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage  # importing prompt tools
from chainlit.playground.providers import ChatOpenAI  # importing ChatOpenAI tools
from dotenv import load_dotenv
import arxiv
import pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore, InMemoryStore
from utils.store import index_documents, search_and_index
from utils.chain import create_chain
from langchain.vectorstores import Pinecone

from langchain.schema.runnable import RunnableSequence
from langchain.schema import format_document
from pprint import pprint
from langchain_core.vectorstores import VectorStoreRetriever
import langchain
from langchain.cache import InMemoryCache
from langchain.memory import ConversationBufferMemory

load_dotenv()
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
INDEX_NAME= 'arxiv-paper-index'
WANDB_API_KEY=os.environ["WANDB_API_KEY"]
WANDB_PROJECT=os.environ["WANDB_PROJECT"]


@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "max_tokens": 500
    }

    await cl.Message(
        content="Hi, I am here to help you learn about a topic, what would you like to learn about today? 😊"
    ).send()

    # create an embedder through a cache interface (locally) (on start)
    store = InMemoryStore()

    core_embeddings_model = AzureOpenAIEmbeddings(
        api_key=os.environ['AZURE_OPENAI_API_KEY'],
        azure_deployment="text-embedding-ada-002",
        azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT']
    )

    embedder = CacheBackedEmbeddings.from_bytes_store(
        underlying_embeddings=core_embeddings_model,
        document_embedding_cache=store,
        namespace=core_embeddings_model.model
    )

    # instantiate pinecone (on start)
    pinecone.init(
        api_key=YOUR_API_KEY,
        environment=YOUR_ENV
    )

    if INDEX_NAME not in pinecone.list_indexes():
        pinecone.create_index(
            name=INDEX_NAME,
            metric='cosine',
            dimension=1536
        )
    index = pinecone.GRPCIndex(INDEX_NAME)

    llm = AzureChatOpenAI(
        temperature=settings['temperature'],
        max_tokens=settings['max_tokens'],
        api_key=os.environ['AZURE_OPENAI_API_KEY'],
        azure_deployment="gpt-35-turbo-16k",
        api_version="2023-07-01-preview",
        streaming=True
    )

    # create a prompt cache (locally) (on start)
    langchain.llm_cache = InMemoryCache()
    
    # log data in WaB (on start)
    os.environ["WANDB_MODE"] = "disabled"
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"

    # setup memory
    memory = ConversationBufferMemory(memory_key="chat_history")

    tools = {
        "index": index,
        "embedder": embedder,
        "llm": llm,
        "memory": memory
    }
    cl.user_session.set("tools", tools)
    cl.user_session.set("settings", settings)
    cl.user_session.set("first_run", False)


@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    settings = cl.user_session.get("settings")
    tools: dict =  cl.user_session.get("tools")
    first_run = cl.user_session.get("first_run")
    retrieval_augmented_qa_chain = cl.user_session.get("chain", None)
    memory: ConversationBufferMemory = cl.user_session.get("memory")

    sys_message = cl.Message(content="")
    await sys_message.send() # renders a loader
    
    if not first_run:
        index: pinecone.GRPCIndex = tools['index']
        embedder: CacheBackedEmbeddings = tools['embedder']
        llm: ChatOpenAI = tools['llm']
        memory: ConversationBufferMemory = tools['memory']

        # using query search for ArXiv documents and index files(on message) 
        await cl.make_async(search_and_index)(message=message, quantity=10, embedder=embedder, index=index)

        text_field = "source_document"
        index = pinecone.Index(INDEX_NAME)
        vectorstore = Pinecone(
            index=index,
            embedding=embedder.embed_query,
            text_key=text_field
        )
        retriever: VectorStoreRetriever = vectorstore.as_retriever()

        # create the chain (on message)
        retrieval_augmented_qa_chain: RunnableSequence = create_chain(retriever=retriever, llm=llm)
        cl.user_session.set("chain", retrieval_augmented_qa_chain)
        
        sys_message.content = """
        I found some papers and studied them πŸ˜‰ \n"""
        await sys_message.update()
    
    # run
    async for chunk in retrieval_augmented_qa_chain.astream({"question": f"{message.content}", "chat_history": memory.buffer_as_messages}):
        if res:= chunk.get('response'):
            await sys_message.stream_token(res.content)
        if chunk.get("context"):
            pprint(chunk.get("context"))
    await sys_message.send()

    memory.chat_memory.add_user_message(message.content)
    memory.chat_memory.add_ai_message(sys_message.content)

    print(memory.buffer_as_str)


    cl.user_session.set("memory", memory)
    cl.user_session.set("first_run", True)