# NOTE: the three lines that were here ("Spaces:" / "Runtime error" x2) were
# Hugging Face Spaces page chrome captured by the scrape, not source code.
# Chainlit + LangChain RAG app over an Airbnb PDF.
# Chainlit Python streaming reference: https://docs.chainlit.io/concepts/streaming/python
import asyncio
import os

import chainlit as cl  # chainlit drives the chat UI and session lifecycle
import tiktoken
from chainlit.prompt import Prompt, PromptMessage  # prompt tools
from chainlit.types import AskFileResponse
from dotenv import load_dotenv
from langchain.document_loaders import PyMuPDFLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

# from chainlit.playground.providers import ChatOpenAI  # superseded by langchain_openai
from utils.custom_retriver import CustomQDrant, CustomVectorStoreRetriever

# Pull OPENAI_API_KEY (and any other secrets) from a local .env into os.environ.
load_dotenv()
# Prompt template for the RAG chain: {context} is filled with retrieved
# chunks, {question} with the raw user query. The trailing instructions keep
# the model grounded in the retrieved context only.
RAG_PROMPT = """
CONTEXT:
{context}
QUERY:
{question}
Answer questions only based on provided context and not your previous knowledge.
In your answer never mention phrases like Based on provided context, From the context etc.
If you don't know the answer say I don't know!
"""
# Load the source PDF once at import time; each page becomes one Document.
data_path = "data/airbnb_midterm.pdf"
docs = PyMuPDFLoader(data_path).load()

# Chat model used to answer questions (swap the name to "gpt-4o" for quality).
openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")  # gpt-4o
def tiktoken_len(text: str, model: str = "gpt-4o") -> int:
    """Return the number of tokens *text* encodes to for *model*.

    Used as the length function of the text splitter below so that chunk
    sizes are measured in tokens rather than characters. *model* defaults to
    "gpt-4o" to preserve the original behaviour, but is now a parameter so
    callers can match a different tokenizer.

    NOTE(review): gpt-4o uses the o200k_base encoding, while the chat model
    above (gpt-3.5-turbo) and the embedder (text-embedding-3-small) both use
    cl100k_base — token counts will differ slightly; confirm which model's
    tokenizer the chunk budget should track.
    """
    # tiktoken caches encodings internally, so per-call lookup is cheap.
    return len(tiktoken.encoding_for_model(model).encode(text))
# Split the PDF into ~500-token chunks with a 10-token overlap, measuring
# length in tokens (tiktoken_len) rather than characters.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=10,
    length_function=tiktoken_len,
)
split_chunks = text_splitter.split_documents(docs)
# Build the prompt, the embedder, and an in-memory Qdrant vector store over
# the chunks, then expose the store as a retriever for the chain below.
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

qdrant_vectorstore = CustomQDrant.from_documents(
    split_chunks,
    embedding_model,
    location=":memory:",  # ephemeral store; rebuilt on every process start
    collection_name="air bnb data",
    score_threshold=0.3,  # drop matches scoring below this similarity
)
qdrant_retriever = qdrant_vectorstore.as_retriever()
from operator import itemgetter

from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# LCEL chain. INVOKE WITH: {"question": "<<SOME USER QUESTION>>"}
#   1. "context"  <- the question piped through the retriever (which, per the
#      handler below, appears to yield (document, score) pairs);
#      "question" <- passed through unchanged
#   2. RunnablePassthrough.assign re-exposes "context" for the next step
#   3. "response" <- prompt formatted with context+question, piped to the LLM;
#      "context" is forwarded alongside so the caller can cite sources
retrieval_augmented_qa_chain = (
    {
        "context": itemgetter("question") | qdrant_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {
        "response": rag_prompt | openai_chat_model,
        "context": itemgetter("context"),
    }
)
def rename(orig_author: str):
    """Map internal author labels to display names; unknown labels pass through.

    "User" -> "You", "Chatbot" -> "Airbnb".

    NOTE(review): looks intended for Chainlit's author-rename hook
    (@cl.author_rename) — confirm it is actually registered somewhere.
    """
    display_names = {"User": "You", "Chatbot": "Airbnb"}
    if orig_author in display_names:
        return display_names[orig_author]
    return orig_author
# Runs at the start of each user session: stash the prebuilt RAG chain on the
# session so the message handler can fetch it per-user.
@cl.on_chat_start  # was missing — without it Chainlit never invokes this hook
async def start_chat():
    cl.user_session.set("chain", retrieval_augmented_qa_chain)
# Runs each time the chatbot receives a message from a user.
@cl.on_message  # was missing — without it Chainlit never routes messages here
async def main(message: cl.Message):
    """Answer *message* via the RAG chain and attach retrieved sources.

    The chain returns {"response": <LLM message>, "context": [...]}; each
    retrieved item (a (document, score) pair, per the indexing below) becomes
    a side-panel Text element named "source_<i> (scr: <score>)".
    """
    chain = cl.user_session.get("chain")
    # ainvoke keeps Chainlit's event loop free; the sync invoke() blocked it.
    resp = await chain.ainvoke({"question": message.content})

    source_documents = resp["context"]
    resp_msg = resp["response"].content

    text_elements = []  # type: list[cl.Text]
    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            # source_doc[0] is the Document, source_doc[1] its relevance score.
            name = "source_{} (scr: {})".format(source_idx, round(source_doc[1], 2))
            text_elements.append(
                cl.Text(content=source_doc[0].page_content, name=name, display="side")
            )

    source_names = [text_el.name for text_el in text_elements]
    if source_names:
        resp_msg += f"\nSources: {', '.join(source_names)}"
    else:
        resp_msg += "\nNo sources found"

    msg = cl.Message(content=resp_msg, elements=text_elements)
    await msg.send()