File size: 4,212 Bytes
ed4ca74
 
 
5ebf50b
 
 
 
 
ed4ca74
5ebf50b
 
ed4ca74
 
 
5ebf50b
 
ed4ca74
 
 
5ebf50b
ed4ca74
5ebf50b
 
 
ed4ca74
5ebf50b
 
 
 
 
 
 
 
 
 
 
 
ed4ca74
5ebf50b
ed4ca74
 
 
 
fe32e45
5ebf50b
 
 
 
 
80a9356
5ebf50b
 
27aa3e6
5ebf50b
80a9356
5ebf50b
 
4a36e50
5ebf50b
 
 
 
 
 
 
8133318
5ebf50b
 
 
 
 
 
ed4ca74
38298ad
5ebf50b
 
 
 
 
 
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
ed4ca74
5ebf50b
 
 
 
ed4ca74
5ebf50b
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
 
 
 
ed4ca74
5ebf50b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
from pathlib import Path
from typing import cast

import chainlit as cl
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer  # Ensure this import is correct

# Copy variables from a local .env file (if any) into os.environ.
load_dotenv()

# ChatOpenAI reads OPENAI_API_KEY from os.environ, so no re-export is needed.
# Fail fast with a readable message: re-assigning os.getenv(...) into
# os.environ would raise an opaque TypeError when the key is missing (None).
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; add it to the environment or a .env file."
    )

class _SentenceTransformerEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around a ``SentenceTransformer`` model.

    ``Qdrant.from_documents`` and its retriever call ``embed_documents`` /
    ``embed_query`` (plus the async variants, which the ``Embeddings`` base
    class supplies) on the embedding object. A bare ``SentenceTransformer``
    only exposes ``encode``, so it cannot be handed to Qdrant directly.
    """

    def __init__(self, model_name):
        self._model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Embed a sequence of texts; return one list of floats per text."""
        texts = list(texts)
        if not texts:
            return []
        # encode() returns a numpy array; vector stores expect plain lists.
        return self._model.encode(texts).tolist()

    def embed_query(self, text):
        """Embed a single query string; return a list of floats."""
        return self._model.encode(text).tolist()


def _build_sentence_retriever():
    """Download both policy PDFs, chunk them, and index them in an in-memory Qdrant store.

    This is blocking work (network downloads plus local model inference), so
    async callers should run it off the event loop. Returns the vector
    store's retriever.
    """
    ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
    ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()

    sentence_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )

    def metadata_generator(document, name, splitter):
        # Split the loaded pages into chunks and label each chunk's "source".
        collection = splitter.split_documents(document)
        for doc in collection:
            doc.metadata["source"] = name
        return collection

    sentence_framework = metadata_generator(ai_framework_document, "AI Framework", sentence_text_splitter)
    sentence_blueprint = metadata_generator(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
    sentence_combined_documents = sentence_framework + sentence_blueprint

    # Wrap the fine-tuned model so Qdrant gets the Embeddings interface it calls.
    embedding_model = _SentenceTransformerEmbeddings("Cheselle/finetuned-arctic-sentence")

    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy",
    )
    return sentence_vectorstore.as_retriever()


@cl.on_chat_start
async def on_chat_start():
    """Set up a chat session: chat model, document retriever and RAG prompt.

    Stores them in the Chainlit user session under the keys "runnable",
    "retriever" and "prompt_template" for ``on_message`` to use.
    """
    model = ChatOpenAI(streaming=True)

    RAG_PROMPT = """\
    Given a provided context and question, you must answer the question based only on context.

    Context: {context}
    Question: {question}
    """

    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

    # Downloading the PDFs and embedding every chunk is blocking; run it in a
    # worker thread so the event loop keeps serving other sessions meanwhile.
    sentence_retriever = await cl.make_async(_build_sentence_retriever)()

    # Set the model, retriever and prompt into the session for reuse
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", sentence_retriever)
    cl.user_session.set("prompt_template", rag_prompt)


@cl.on_message
async def on_message(message: cl.Message):
    """Answer *message* with retrieval-augmented generation and stream the reply.

    Reads the chat model, retriever and prompt template that ``on_chat_start``
    stored in the user session, retrieves chunks relevant to the message
    text, fills the prompt with them, and streams the model's answer into a
    single Chainlit message. Replies with an apology and returns early when
    nothing relevant is retrieved.
    """
    # Get the stored model, retriever, and prompt
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    # on_chat_start may have failed part-way (e.g. a PDF download error),
    # leaving these unset; reply instead of raising AttributeError below.
    if model is None or retriever is None or prompt_template is None:
        await cl.Message(content="The document index is not ready; please start a new chat.").send()
        return

    # Log the message content
    print(f"Received message: {message.content}")

    # invoke() replaces the deprecated get_relevant_documents(). It embeds the
    # query synchronously, so run it in a worker thread to keep the event
    # loop responsive for other sessions.
    relevant_docs = await cl.make_async(retriever.invoke)(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join(doc.page_content for doc in relevant_docs)
    print(f"Context: {context}")

    # format_messages() keeps the chat structure (a human message). format()
    # would flatten the template to a "Human: ..." string, and the model would
    # receive that role label as literal text inside the user turn.
    prompt_messages = prompt_template.format_messages(context=context, question=message.content)
    print("Final prompt: " + "\n".join(str(m.content) for m in prompt_messages))

    # Stream the model's reply token by token into one Chainlit message.
    msg = cl.Message(content="")
    async for chunk in model.astream(
        prompt_messages,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk.content)

    await msg.send()