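"""Chainlit app: upload a text or PDF file, index it in a simple vector
database, and answer questions about it via retrieval-augmented generation."""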
import os
import tempfile
from chainlit.types import AskFileResponse
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
from aimakerspace.openai_utils.prompts import UserRolePrompt, SystemRolePrompt
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import chainlit as cl
from PyPDF2 import PdfReader

system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = "Context:\n{context}\n\nQuestion:\n{question}"
user_role_prompt = UserRolePrompt(user_prompt_template)

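# Minimal RAG pipeline: retrieve the most similar chunks for a query, then
# stream an LLM answer grounded in that retrieved context.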
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
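        # Each search result is a (chunk_text, similarity_score) tuple; keep only the text.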
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
        context_prompt = "\n".join([context[0] for context in context_list])
        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

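        # Stream the model's tokens as they arrive so the caller can forward them to the UI.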
        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "context": context_list}

text_splitter = CharacterTextSplitter()

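# Persist the upload to a temp file, load it with the appropriate loader,
# and split the text into chunks for embedding.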
def process_file(file: AskFileResponse):
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file.name) as temp_file:
        temp_file.write(file.content)
        temp_file_path = temp_file.name

    if file.type == "text/plain":
        text_loader = TextFileLoader(temp_file_path)
        documents = text_loader.load_documents()
    elif file.type == "application/pdf":
        pdf_reader = PdfReader(temp_file_path)
        # extract_text() can return None for some pages; guard with an empty string.
        documents = [page.extract_text() or "" for page in pdf_reader.pages]
    else:
        raise ValueError(f"Unsupported file type: {file.type}")

    texts = text_splitter.split_texts(documents)
    os.unlink(temp_file_path)
    return texts

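# On session start: prompt for a file, index it, and stash the pipeline in the user session.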
@cl.on_chat_start
async def on_chat_start():
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()

    file = files[0]
    msg = cl.Message(content=f"Processing `{file.name}`...", disable_human_feedback=True)
    await msg.send()

    texts = process_file(file)
    print(f"Processing {len(texts)} text chunks")

    vector_db = VectorDatabase()
    vector_db = await vector_db.abuild_from_list(texts)
    
    chat_openai = ChatOpenAI()
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db, llm=chat_openai)
    
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)

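# On each user message: run the RAG pipeline and stream the answer back token by token.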
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    msg = cl.Message(content="")
    result = await chain.arun_pipeline(message.content)

    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    await msg.send()