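# Chainlit RAG app: answers questions about Paul Graham's essays using a FAISS
# vectorstore and Hugging Face Inference Endpoints for embeddings and generation.
# Run locally with: chainlit run app.py -w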
import os
import chainlit as cl
from dotenv import load_dotenv
from operator import itemgetter
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEndpointEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables.config import RunnableConfig

# GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
load_dotenv()

HF_LLM_ENDPOINT = os.environ["HF_LLM_ENDPOINT"]
HF_EMBED_ENDPOINT = os.environ["HF_EMBED_ENDPOINT"]
HF_TOKEN = os.environ["HF_TOKEN"]

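# Embeddings are computed by a dedicated Hugging Face Inference Endpoint; the
# FAISS index is loaded from disk if present, otherwise built on first run.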
vectorstore_path = "./data/vectorstore"
index_file = os.path.join(vectorstore_path, "index.faiss")

hf_embeddings = HuggingFaceEndpointEmbeddings(
    model=HF_EMBED_ENDPOINT,
    task="feature-extraction",
    huggingfacehub_api_token=HF_TOKEN,
)

if os.path.exists(index_file):
    # allow_dangerous_deserialization is required because FAISS indexes are
    # pickled on disk; only enable it for index files you built yourself.
    vectorstore = FAISS.load_local(
        vectorstore_path,
        hf_embeddings,
        allow_dangerous_deserialization=True,
    )
    print("Loaded Vectorstore")
else:
    # No index yet: build one from the source documents. The corpus path below
    # is an assumption -- point it at the text you want to index.
    print("Indexing Files")
    os.makedirs(vectorstore_path, exist_ok=True)
    documents = TextLoader("./data/paul_graham_essays.txt").load()
    split_documents = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=30
    ).split_documents(documents)
    vectorstore = FAISS.from_documents(split_documents, hf_embeddings)
    vectorstore.save_local(vectorstore_path)

hf_retriever = vectorstore.as_retriever()

# Chat-style prompt using Llama 3 special tokens; adjust the markers if the
# endpoint serves a model with a different chat template.
RAG_PROMPT_TEMPLATE = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
User Query:
{query}
Context:
{context}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

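# Text-generation endpoint; low temperature keeps answers grounded in the
# retrieved context, and max_new_tokens caps response length.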
hf_llm = HuggingFaceEndpoint(
    endpoint_url=HF_LLM_ENDPOINT,
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    temperature=0.1,
    repetition_penalty=1.0,
    huggingfacehub_api_token=HF_TOKEN,
)

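# Display a friendlier author name in the UI instead of the default "Assistant".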
@cl.author_rename
def rename(original_author: str):
    rename_dict = {
        "Assistant": "Paul Graham Essay Bot"
    }
    return rename_dict.get(original_author, original_author)

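# Build the LCEL RAG chain once per session: the retriever fills {context}
# from the user's query, then the prompt and LLM generate the answer.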
@cl.on_chat_start
async def start_chat():
    try:
        lcel_rag_chain = (
            {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
            | rag_prompt | hf_llm
        )
        cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
    except Exception as e:
        print(f"Failed to build RAG chain on chat start: {e}")

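# Answer each incoming message by invoking the session's RAG chain.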
@cl.on_message
async def main(message: cl.Message):
    try:
        lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
        if lcel_rag_chain is None:
            await cl.Message(content="Session has expired. Please restart the chat.").send()
            return

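        # Stream tokens to the UI as they arrive; the LangChain callback handler
        # also surfaces intermediate chain steps in the Chainlit interface.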
        msg = cl.Message(content="")
        async for chunk in lcel_rag_chain.astream(
            {"query": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)
        await msg.send()
    except Exception as e:
        await cl.Message(content="An error occurred. Please restart the chat.").send()
        print(f"Error while handling message: {e}")