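"""Chainlit RAG app: index an uploaded text or PDF file into an in-memory
Qdrant collection and answer questions about it with OpenAI, streaming the
response token by token."""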
import os
import tempfile
from typing import List

import chainlit as cl
from chainlit.types import AskFileResponse
from PyPDF2 import PdfReader
from qdrant_client import QdrantClient
from qdrant_client.http import models

from aimakerspace.openai_utils.chatmodel import ChatOpenAI
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.openai_utils.prompts import SystemRolePrompt, UserRolePrompt
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
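
# Prompt templates: the system prompt pins answers to the retrieved context;
# the user prompt carries the context block and the question.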
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = "Context:\n{context}\n\nQuestion:\n{question}"
user_role_prompt = UserRolePrompt(user_prompt_template)

class QdrantVectorStore:
    """Minimal vector store backed by an in-memory Qdrant collection."""

    def __init__(self, collection_name="my_collection"):
        # ":memory:" runs Qdrant in-process; the index is lost when the app stops.
        self.client = QdrantClient(":memory:")
        self.collection_name = collection_name
        self.embedding_model = EmbeddingModel()

    async def abuild_from_list(self, texts: List[str]):
        # Size 1536 matches OpenAI text-embedding vectors; cosine distance for similarity.
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
        )
        for i, text in enumerate(texts):
            vector = await self.embedding_model.aembed_query(text)
            self.client.upsert(
                collection_name=self.collection_name,
                points=[models.PointStruct(id=i, vector=vector, payload={"text": text})],
            )
        return self

    def search_by_text(self, query: str, k: int = 4):
        # Embed the query, then return the top-k (text, score) pairs.
        vector = self.embedding_model.embed_query(query)
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=vector,
            limit=k,
        )
        return [(hit.payload["text"], hit.score) for hit in results]
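
# A minimal standalone sketch of the store (assumes OPENAI_API_KEY is set for
# EmbeddingModel's OpenAI calls):
#
#     store = await QdrantVectorStore("demo").abuild_from_list(["alpha", "beta"])
#     hits = store.search_by_text("alpha", k=1)  # -> [("alpha", <score>)]
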
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: QdrantVectorStore) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        # Retrieve the top-k chunks and join them into a single context block.
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
        context_prompt = "\n".join([context[0] for context in context_list])

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "context": context_list}
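
# Sketch of driving the pipeline outside Chainlit (names here are illustrative):
#
#     pipeline = RetrievalAugmentedQAPipeline(llm=ChatOpenAI(), vector_db_retriever=store)
#     result = await pipeline.arun_pipeline("What is this document about?")
#     async for token in result["response"]:
#         print(token, end="")
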
text_splitter = CharacterTextSplitter()


def process_file(file: AskFileResponse):
    # Persist the upload to a temp file so the loaders can read it from disk.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file.name) as temp_file:
        temp_file.write(file.content)
        temp_file_path = temp_file.name

    if file.type == "text/plain":
        text_loader = TextFileLoader(temp_file_path)
        documents = text_loader.load_documents()
    elif file.type == "application/pdf":
        pdf_reader = PdfReader(temp_file_path)
        documents = [page.extract_text() for page in pdf_reader.pages]
    else:
        raise ValueError(f"Unsupported file type: {file.type}")

    texts = text_splitter.split_texts(documents)
    os.unlink(temp_file_path)
    return texts
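
# Chainlit lifecycle: on_chat_start runs once per chat session, before any
# user message, so the index is built before questions arrive.
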
@cl.on_chat_start
async def on_chat_start():
    files = None

    # Keep prompting until the user uploads a file.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...", disable_human_feedback=True)
    await msg.send()

    # Split the upload into chunks and index them in Qdrant.
    texts = process_file(file)
    print(f"Processing {len(texts)} text chunks")

    vector_db = QdrantVectorStore()
    vector_db = await vector_db.abuild_from_list(texts)

    chat_openai = ChatOpenAI()
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db, llm=chat_openai
    )

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    # Store the pipeline in the user session so on_message can reuse it.
    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)

@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    msg = cl.Message(content="")

    # Stream tokens to the UI as the LLM produces them.
    result = await chain.arun_pipeline(message.content)
    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    await msg.send()
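
# Typically launched with `chainlit run app.py -w` (filename assumed).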