File size: 4,862 Bytes
ed4ca74
 
 
5ebf50b
 
 
 
 
ed4ca74
5ebf50b
 
ed4ca74
 
 
5ebf50b
6cca36c
ed4ca74
 
 
5ebf50b
ed4ca74
6cca36c
 
 
 
 
 
 
bc4b9a8
 
 
5ebf50b
 
 
ed4ca74
5ebf50b
 
 
 
 
 
 
 
 
 
 
 
ed4ca74
5ebf50b
ed4ca74
 
 
 
fe32e45
5ebf50b
 
 
 
 
80a9356
5ebf50b
 
27aa3e6
5ebf50b
80a9356
6cca36c
 
4a36e50
6cca36c
5ebf50b
 
 
 
 
 
8133318
5ebf50b
6cca36c
 
 
 
 
5ebf50b
 
 
 
ed4ca74
38298ad
5ebf50b
 
 
 
 
 
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
6cca36c
 
 
 
 
5ebf50b
 
ed4ca74
5ebf50b
 
 
 
ed4ca74
5ebf50b
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
ed4ca74
5ebf50b
 
ed4ca74
5ebf50b
 
 
 
 
 
ed4ca74
5ebf50b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from typing import cast
from dotenv import load_dotenv
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import chainlit as cl
from pathlib import Path
from sentence_transformers import SentenceTransformer

load_dotenv()

# load_dotenv() populates os.environ from a .env file, so ChatOpenAI can read
# OPENAI_API_KEY directly; re-assigning it to itself is unnecessary. The old
# `os.environ[...] = os.getenv(...)` also crashed with an opaque
# "TypeError: str expected, not NoneType" when the key was missing -- fail at
# startup with an explicit message instead.
if os.getenv("OPENAI_API_KEY") is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file."
    )

class SentenceTransformerEmbedding:
    """LangChain-style embeddings adapter around a SentenceTransformer model.

    LangChain vector stores call ``embed_documents`` when indexing and
    ``embed_query`` when searching, so both are provided. Vectors are returned
    as plain Python lists of floats.
    """

    def __init__(self, model_name):
        """Load the SentenceTransformer model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Embed a list of texts, returning one list of floats per text."""
        # encode() yields a tensor; .tolist() converts it to nested Python lists
        # so the vector store receives plain floats.
        return self.model.encode(texts, convert_to_tensor=True).tolist()

    def embed_query(self, text):
        """Embed a single query string, returning a single list of floats.

        Needed for retrieval: the vector store embeds the user's question via
        this method, and without it similarity search raises AttributeError.
        """
        return self.model.encode(text, convert_to_tensor=True).tolist()

    def __call__(self, texts):
        """Alias for ``embed_documents`` so the instance can be used as a function."""
        return self.embed_documents(texts)

def _tag_chunks(documents, name, splitter):
    """Split *documents* with *splitter* and label every chunk with *name*.

    Any existing ``metadata["source"]`` value on a chunk is overwritten.
    """
    chunks = splitter.split_documents(documents)
    for chunk in chunks:
        chunk.metadata["source"] = name
    return chunks


def _build_retriever():
    """Download, chunk and embed the two policy PDFs; return a retriever.

    Everything here is synchronous (HTTP download, PDF parsing, model loading,
    embedding every chunk), so async callers should run it in a worker thread.
    """
    ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
    ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()

    sentence_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )

    sentence_combined_documents = (
        _tag_chunks(ai_framework_document, "AI Framework", sentence_text_splitter)
        + _tag_chunks(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
    )

    # Custom adapter exposing embed_documents for indexing.
    embedding_model = SentenceTransformerEmbedding('Cheselle/finetuned-arctic-sentence')

    # In-memory Qdrant collection: lives only as long as this process.
    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy",
    )

    sentence_retriever = sentence_vectorstore.as_retriever()
    if sentence_retriever is None:
        raise ValueError("Retriever is not initialized correctly.")
    return sentence_retriever


@cl.on_chat_start
async def on_chat_start():
    """Prepare the chat model, RAG prompt and document retriever for a new session.

    Stores them in the Chainlit user session under "runnable", "retriever"
    and "prompt_template" for ``on_message`` to use.
    """
    model = ChatOpenAI(streaming=True)

    RAG_PROMPT = """\
    Given a provided context and question, you must answer the question based only on context.

    Context: {context}
    Question: {question}
    """

    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

    # Building the index is blocking I/O + CPU work; run it off the event loop
    # so other sessions stay responsive while it completes.
    # NOTE(review): this rebuilds the whole index for every chat session;
    # consider caching the retriever at process level if that cost matters.
    sentence_retriever = await cl.make_async(_build_retriever)()

    # Set the model, retriever and prompt into session for reuse.
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", sentence_retriever)
    cl.user_session.set("prompt_template", rag_prompt)


@cl.on_message
async def on_message(message: cl.Message):
    """Answer *message* with a retrieval-augmented response streamed to the UI.

    Uses the model, retriever and prompt stored in the user session by
    ``on_chat_start``. Sends an apology message and returns early when the
    retriever is missing or finds no documents.
    """
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    print(f"Received message: {message.content}")

    if retriever is None:
        print("Retriever is not available.")
        await cl.Message(content="Sorry, the retriever is not initialized.").send()
        return

    # `get_relevant_documents` is deprecated in favour of `invoke`. The call is
    # synchronous (the query is embedded by a local model), so run it in a
    # worker thread rather than blocking the event loop.
    relevant_docs = await cl.make_async(retriever.invoke)(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join(doc.page_content for doc in relevant_docs)
    print(f"Context: {context}")

    # format_messages() yields chat messages the model consumes directly;
    # format() would flatten them to a "Human: ..." string that the model then
    # wraps in yet another human message.
    final_prompt = prompt_template.format_messages(context=context, question=message.content)
    print("Final prompt: " + "\n".join(str(m.content) for m in final_prompt))

    msg = cl.Message(content="")

    # Stream tokens to the UI as they arrive.
    async for chunk in model.astream(
        final_prompt,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk.content)

    await msg.send()