File size: 5,005 Bytes
ed4ca74
 
 
5ebf50b
 
 
 
ed4ca74
5ebf50b
ed4ca74
 
6cca36c
ed4ca74
b83d3fb
ed4ca74
5ebf50b
ed4ca74
6cca36c
 
 
 
 
 
 
bc4b9a8
 
 
ffd0213
5ebf50b
 
ed4ca74
5ebf50b
 
 
ffd0213
 
 
5ebf50b
 
 
 
 
 
 
 
ed4ca74
5ebf50b
ed4ca74
 
 
 
fe32e45
5ebf50b
 
 
 
 
80a9356
5ebf50b
 
27aa3e6
5ebf50b
80a9356
ffd0213
 
 
 
 
 
b83d3fb
6cca36c
4a36e50
b83d3fb
5ebf50b
 
ffd0213
5ebf50b
 
 
8133318
ffd0213
 
b83d3fb
5ebf50b
6cca36c
 
 
 
 
5ebf50b
 
 
 
ed4ca74
ffd0213
5ebf50b
 
b83d3fb
5ebf50b
 
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
6cca36c
 
 
 
 
5ebf50b
 
ed4ca74
5ebf50b
 
 
 
ed4ca74
5ebf50b
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
 
 
 
ed4ca74
5ebf50b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from dotenv import load_dotenv
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
import chainlit as cl
from sentence_transformers import SentenceTransformer

# Load environment variables from a local .env file (if present) into os.environ.
load_dotenv()

# ChatOpenAI (constructed without an explicit api_key) reads OPENAI_API_KEY from
# the environment. Fail fast with an actionable message when it is missing
# instead of letting a later call fail obscurely (assigning None into
# os.environ would raise a bare TypeError).
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in the environment or in a .env file."
    )

class SentenceTransformerEmbedding:
    """Expose a SentenceTransformer model through LangChain's embeddings interface.

    LangChain vector stores use their ``embedding`` argument by duck typing:
    ``embed_documents`` is called when indexing texts and ``embed_query`` is
    called for the search text at retrieval time. Both return plain Python
    lists of floats.
    """

    def __init__(self, model_name):
        """Load the SentenceTransformer model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Return one embedding vector (a list of floats) per string in *texts*."""
        # encode() yields a tensor here; tolist() turns it into nested Python
        # lists, which is what the vector store expects.
        return self.model.encode(texts, convert_to_tensor=True).tolist()

    def embed_query(self, text):
        """Return the embedding vector (a list of floats) for a single query string.

        Vector-store retrievers call this for the user's question; without it
        similarity search cannot embed the query.
        """
        # A single string encodes to a 1-D tensor, so tolist() gives list[float].
        return self.model.encode(text, convert_to_tensor=True).tolist()

    def __call__(self, texts):
        """Alias for ``embed_documents`` so the instance can be used as a plain function."""
        return self.embed_documents(texts)

def _build_sentence_retriever(max_documents=None):
    """Download the two policy PDFs, chunk and embed them, and return a retriever.

    This performs blocking network, model and indexing work, so callers in
    async code should run it in a worker thread.

    Args:
        max_documents: Optional cap on how many chunks are embedded. ``None``
            (the default) indexes every chunk from both documents; a small
            int speeds up local debugging at the cost of an incomplete index.

    Returns:
        A retriever over an in-memory Qdrant collection named "AI Policy".
    """
    ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
    ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()

    print("Documents loaded.")

    sentence_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"]
    )

    def metadata_generator(document, name, splitter):
        """Split *document* with *splitter* and set each chunk's "source" metadata to *name*."""
        collection = splitter.split_documents(document)
        for doc in collection:
            doc.metadata["source"] = name
        return collection

    sentence_combined_documents = (
        metadata_generator(ai_framework_document, "AI Framework", sentence_text_splitter)
        + metadata_generator(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
    )

    print(f"Total documents to embed: {len(sentence_combined_documents)}")

    # Only truncate when explicitly asked to; a hard-coded cap silently drops
    # most of the corpus (the second document entirely).
    if max_documents is not None:
        sentence_combined_documents = sentence_combined_documents[:max_documents]

    embedding_model = SentenceTransformerEmbedding('Cheselle/finetuned-arctic-sentence')

    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy"
    )

    print("Vector store created.")

    return sentence_vectorstore.as_retriever()


@cl.on_chat_start
async def on_chat_start():
    """Prepare the RAG pipeline for a new chat session.

    Stores three keys in the Chainlit user session for ``on_message``:
    ``"runnable"`` (the streaming chat model), ``"retriever"`` and
    ``"prompt_template"``.
    """
    model = ChatOpenAI(streaming=True)

    RAG_PROMPT = """\
    Given a provided context and question, you must answer the question based only on context.

    Context: {context}
    Question: {question}
    """

    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

    # Downloading, embedding and indexing block for a long time; run them in a
    # worker thread so the event loop (and other sessions) is not stalled.
    sentence_retriever = await cl.make_async(_build_sentence_retriever)()

    # Set the model, retriever and prompt into the session for reuse.
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", sentence_retriever)
    cl.user_session.set("prompt_template", rag_prompt)

@cl.on_message
async def on_message(message: cl.Message):
    """Answer *message* with retrieval-augmented generation, streaming the reply.

    Retrieves chunks relevant to the message text, fills the session's RAG
    prompt with them, and streams the chat model's answer into a new Chainlit
    message. Replies with an apology instead when the session has no
    retriever or no relevant chunks are found.
    """
    # Get the stored model, retriever, and prompt (set by on_chat_start).
    model = cl.user_session.get("runnable")
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    print(f"Received message: {message.content}")

    if retriever is None:
        print("Retriever is not available.")
        await cl.Message(content="Sorry, the retriever is not initialized.").send()
        return

    # invoke() supersedes the deprecated get_relevant_documents(). The search
    # (query embedding + vector lookup) is synchronous, so run it in a worker
    # thread rather than blocking the event loop.
    relevant_docs = await cl.make_async(retriever.invoke)(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join(doc.page_content for doc in relevant_docs)
    print(f"Context: {context}")

    # format_prompt() keeps the chat-message structure. ChatPromptTemplate.format()
    # flattens it to a single "Human: ..." string, which the model would then
    # receive verbatim as the user's text, role prefix included.
    final_prompt = prompt_template.format_prompt(context=context, question=message.content)
    print(f"Final prompt: {final_prompt.to_string()}")

    # Stream the model's response token by token into one Chainlit message.
    msg = cl.Message(content="")
    async for chunk in model.astream(
        final_prompt,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk.content)

    await msg.send()