# midterm-airbnb / app.py
# Author: mpav — commit b36ae4a ("my midterm1")
# Source-viewer page residue: raw | history | blame — 4.54 kB
# Adapted from the Chainlit Python streaming example (https://docs.chainlit.io/concepts/streaming/python); note that this app sends each answer as one complete message rather than streaming tokens.
import os
from openai import AsyncOpenAI # importing openai for API usage
import chainlit as cl # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage # importing prompt tools
#from chainlit.playground.providers import ChatOpenAI # importing ChatOpenAI tools
from dotenv import load_dotenv
from chainlit.types import AskFileResponse
import asyncio
from langchain.document_loaders import PyMuPDFLoader, PyPDFLoader
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from utils.custom_retriver import CustomQDrant, CustomVectorStoreRetriever
load_dotenv()

# Prompt for the RAG chain: {context} is filled with the retriever's output and
# {question} with the user's message.
RAG_PROMPT = """
CONTEXT:
{context}
QUERY:
{question}
Answer questions only based on provided context and not your previous knowledge.
In your answer never mention phrases like Based on provided context, From the context etc.
If you don't know the answer say I don't know!
"""

# Resolve the PDF relative to this module rather than the current working
# directory, so the app can be launched from any directory.
data_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "airbnb_midterm.pdf"
)
docs = PyMuPDFLoader(data_path).load()

openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")  # alternative: gpt-4o
def tiktoken_len(text, *, model="gpt-4o"):
    """Return the number of tokens in *text* under *model*'s tiktoken encoding.

    This module passes it to the text splitter as its length function, so chunk
    sizes are measured in tokens rather than characters.

    Args:
        text: The string to measure.
        model: Model name whose tiktoken encoding is used. Defaults to
            "gpt-4o", the encoding this function has always used.

    Returns:
        The number of tokens in *text*.
    """
    encoding = tiktoken.encoding_for_model(model)
    # By default tiktoken raises ValueError when the text contains a special-token
    # string such as "<|endoftext|>". That would abort splitting any document that
    # happens to contain one. disallowed_special=() counts such strings as plain text.
    return len(encoding.encode(text, disallowed_special=()))
# Prompt template and embedding model shared by the pipeline below.
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

# Split the loaded PDF into chunks with a target size of 500 tokens (measured
# by tiktoken_len) and a 10-token overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=10,
    length_function=tiktoken_len,
)
split_chunks = text_splitter.split_documents(docs)

# Embed the chunks into an in-memory Qdrant collection and expose it as a
# retriever. `score_threshold` is forwarded to CustomQDrant.from_documents;
# presumably it drops low-similarity hits — confirm in utils.custom_retriver.
qdrant_vectorstore = CustomQDrant.from_documents(
    split_chunks,
    embedding_model,
    location=":memory:",
    collection_name="air bnb data",
    score_threshold=0.3,
)
qdrant_retriever = qdrant_vectorstore.as_retriever()
from operator import itemgetter
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
# Retrieval-augmented QA chain.
# Invoke with {"question": "<user question>"}. Returns a dict with:
#   "response": the chat model's output message (main() reads its .content)
#   "context":  whatever qdrant_retriever returned for the question
retrieval_augmented_qa_chain = (
    # Step 1: retrieve context for the question and carry the question through.
    {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
    # Step 2: format the prompt with context + question and call the LLM. The
    # retrieved context is kept next to the answer so the UI can list sources.
    | {"response": rag_prompt | openai_chat_model, "context": itemgetter("context")}
)
@cl.author_rename
def rename(orig_author: str):
    """Return the display name for an author label.

    "User" is shown as "You" and "Chatbot" as "Airbnb"; any other label is
    returned unchanged.
    """
    display_names = {"Chatbot": "Airbnb", "User": "You"}
    if orig_author in display_names:
        return display_names[orig_author]
    return orig_author
@cl.on_chat_start  # runs once at the start of each user session
async def start_chat():
    """Store the module-level RAG chain in this user's session under "chain"."""
    session = cl.user_session
    session.set("chain", retrieval_augmented_qa_chain)
@cl.on_message  # runs each time the chatbot receives a message from a user
async def main(message: cl.Message):
    """Answer a user message with the session's RAG chain and attach its sources.

    Each retrieved item becomes a side-panel text element named
    ``source_<i> (scr: <score>)``. The reply lists those names, or says
    "No sources found" when retrieval returned nothing.
    """
    chain = cl.user_session.get("chain")
    # chain.invoke is synchronous and does network retrieval plus an LLM call.
    # Run it in a worker thread so it does not block the event loop shared by
    # all sessions. The same sync code path is used as before.
    resp = await cl.make_async(chain.invoke)({"question": message.content})

    # Items are indexed as (document, score) pairs. Presumably that is what the
    # custom Qdrant retriever returns — verify in utils.custom_retriver.
    source_documents = resp["context"] or []
    text_elements = [
        cl.Text(
            content=doc_and_score[0].page_content,
            name="source_{} (scr: {})".format(idx, round(doc_and_score[1], 2)),
            display="side",
        )
        for idx, doc_and_score in enumerate(source_documents)
    ]

    resp_msg = resp["response"].content
    # Check the elements directly so an empty retrieval really reports
    # "No sources found" instead of appending nothing.
    if text_elements:
        source_names = ", ".join(element.name for element in text_elements)
        resp_msg += f"\nSources: {source_names}"
    else:
        resp_msg += "\nNo sources found"

    await cl.Message(content=resp_msg, elements=text_elements).send()