File size: 4,212 Bytes
e7f1639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda9a43
e7f1639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda9a43
e7f1639
 
bda9a43
e7f1639
 
bda9a43
e7f1639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from typing import cast
from dotenv import load_dotenv
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import chainlit as cl
from pathlib import Path
from sentence_transformers import SentenceTransformer  # Ensure this import is correct

# Load variables from a local .env file (if one exists) into the process
# environment before anything reads them.
load_dotenv()

# The previous self-assignment (os.environ[k] = os.getenv(k)) was a no-op when
# the key was set and raised an opaque "TypeError: str expected, not NoneType"
# when it was missing. Fail fast with an actionable message instead.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export it in the environment or add it "
        "to a .env file next to this script."
    )

def _build_retriever():
    """Download both policy PDFs, chunk and embed them, and return a retriever.

    Everything in here is synchronous, blocking work (HTTP download, PDF
    parsing, model loading, embedding), so callers running on an event loop
    should execute it in a worker thread.

    Returns:
        A retriever over an in-memory Qdrant collection named "AI Policy".
    """
    # Local import so this block is self-contained and does not rely on
    # changes to the file's import section.
    from langchain_core.embeddings import Embeddings

    class _SentenceTransformerEmbeddings(Embeddings):
        """Expose a SentenceTransformer through LangChain's Embeddings API.

        Qdrant.from_documents calls ``embed_documents`` / ``embed_query`` on
        the object it is given; a raw SentenceTransformer only has ``encode``.
        """

        def __init__(self, st_model):
            self._st_model = st_model

        def embed_documents(self, texts):
            # encode() returns a 2-D numpy array for a list of strings.
            return self._st_model.encode(list(texts)).tolist()

        def embed_query(self, text):
            # encode() returns a 1-D numpy array for a single string.
            return self._st_model.encode(text).tolist()

    def _split_with_source(pages, source_name, splitter):
        """Split loaded pages into chunks and tag every chunk's "source"."""
        chunks = splitter.split_documents(pages)
        for chunk in chunks:
            chunk.metadata["source"] = source_name
        return chunks

    # Load documents
    ai_framework_document = PyMuPDFLoader(
        file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
    ).load()
    ai_blueprint_document = PyMuPDFLoader(
        file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
    ).load()

    sentence_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )

    sentence_combined_documents = (
        _split_with_source(ai_framework_document, "AI Framework", sentence_text_splitter)
        + _split_with_source(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
    )

    # Wrap the fine-tuned SentenceTransformer so the vector store can use it.
    embedding_model = _SentenceTransformerEmbeddings(
        SentenceTransformer('Cheselle/finetuned-arctic-sentence')
    )

    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy",
    )
    return sentence_vectorstore.as_retriever()


@cl.on_chat_start
async def on_chat_start():
    """Prepare the chat model, RAG prompt and retriever for a new chat session.

    Stores three objects in the Chainlit user session for ``on_message``:
    ``"runnable"`` (the streaming ChatOpenAI model), ``"retriever"`` and
    ``"prompt_template"``.
    """
    model = ChatOpenAI(streaming=True)

    RAG_PROMPT = """\
    Given a provided context and question, you must answer the question based only on context.

    Context: {context}
    Question: {question}
    """

    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

    # Index construction is blocking; run it in a worker thread so the event
    # loop stays responsive while the documents are downloaded and embedded.
    # NOTE(review): the index is rebuilt for every new session; consider
    # building it once and sharing it if start-up latency matters.
    sentence_retriever = await cl.make_async(_build_retriever)()

    # Set the model, retriever and prompt into session for reuse
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", sentence_retriever)
    cl.user_session.set("prompt_template", rag_prompt)


@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message from retrieved document context, streaming the reply.

    Reads the model, retriever and prompt template stored in the user session
    under ``"runnable"``, ``"retriever"`` and ``"prompt_template"``, retrieves
    documents for the message text, fills the prompt, and streams the model's
    output back to the UI token by token.

    Args:
        message: The incoming Chainlit message; its ``content`` is used both
            as the retrieval query and as the question in the prompt.
    """
    # Get the stored model, retriever, and prompt
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    # If session setup did not complete, these are None; tell the user rather
    # than failing with an AttributeError further down.
    if model is None or retriever is None or prompt_template is None:
        print("Session is not initialised; cannot answer message.")
        await cl.Message(
            content="Sorry, the assistant is not ready yet. Please start a new chat."
        ).send()
        return

    # Log the message content
    print(f"Received message: {message.content}")

    # Retrieve relevant context. `invoke` replaces the deprecated
    # `get_relevant_documents`; it is synchronous, so run it off the event loop.
    relevant_docs = await cl.make_async(retriever.invoke)(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join(doc.page_content for doc in relevant_docs)

    # Log the context to check
    print(f"Context: {context}")

    # Build chat messages rather than a flattened string:
    # ChatPromptTemplate.format() renders to text prefixed with "Human: ",
    # which would then be sent to the model as part of the user message.
    final_prompt = prompt_template.format_messages(
        context=context, question=message.content
    )
    print(f"Final prompt: {final_prompt}")

    # Initialize a streaming message
    msg = cl.Message(content="")

    # Stream the response from the model
    async for chunk in model.astream(
        final_prompt,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        # Some stream chunks carry no text; skip them.
        if chunk.content:
            await msg.stream_token(chunk.content)

    await msg.send()