Cheselle committed on
Commit e7f1639 · verified · 1 Parent(s): c52349b

Create app_v1.py

Files changed (1)
  1. app_v1.py +113 -0
app_v1.py ADDED
@@ -0,0 +1,113 @@
+ from langchain_openai import ChatOpenAI
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema.runnable.config import RunnableConfig
+ from typing import cast
+ from dotenv import load_dotenv
+ import os
+ from langchain_community.document_loaders import PyMuPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Qdrant
+ import chainlit as cl
+
+ load_dotenv()
+
+ # load_dotenv() already exports OPENAI_API_KEY into the environment;
+ # fail fast with a clear error if it is missing (re-assigning
+ # os.getenv(...) into os.environ would raise a TypeError when unset).
+ if not os.getenv("OPENAI_API_KEY"):
+     raise RuntimeError("OPENAI_API_KEY is not set")
+
+ @cl.on_chat_start
+ async def on_chat_start():
+     model = ChatOpenAI(streaming=True)
+
+     # Load the two policy PDFs straight from their source URLs
+     ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
+     ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
+
+     RAG_PROMPT = """\
+ Given a provided context and question, you must answer the question based only on the context.
+
+ Context: {context}
+ Question: {question}
+ """
+
+     rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
+
+     sentence_text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=500,
+         chunk_overlap=100,
+         separators=["\n\n", "\n", ".", "!", "?"]
+     )
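+     # 500-character chunks with 100 characters of overlap, splitting
+     # preferentially on paragraph, line, then sentence boundaries so
+     # chunks stay roughly sentence-aligned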
+
+     # Split a document into chunks and tag each chunk with its source name
+     def metadata_generator(document, name, splitter):
+         collection = splitter.split_documents(document)
+         for doc in collection:
+             doc.metadata["source"] = name
+         return collection
+
+     sentence_framework = metadata_generator(ai_framework_document, "AI Framework", sentence_text_splitter)
+     sentence_blueprint = metadata_generator(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
+
+     sentence_combined_documents = sentence_framework + sentence_blueprint
+
+     # Wrap the fine-tuned sentence-transformers model in LangChain's
+     # HuggingFaceEmbeddings adapter: Qdrant.from_documents() expects a
+     # LangChain Embeddings object, not a raw SentenceTransformer.
+     embedding_model = HuggingFaceEmbeddings(model_name="Cheselle/finetuned-arctic-sentence")
+
+     # Build an in-memory Qdrant vector store over the combined chunks
+     sentence_vectorstore = Qdrant.from_documents(
+         documents=sentence_combined_documents,
+         embedding=embedding_model,
+         location=":memory:",
+         collection_name="AI Policy"
+     )
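+
+     # as_retriever() with no arguments defaults to similarity search over
+     # the top 4 chunks; pass search_kwargs={"k": ...} to tune how much
+     # context reaches the prompt.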
+     sentence_retriever = sentence_vectorstore.as_retriever()
+
+     # Store the model, retriever, and prompt in the session for reuse
+     cl.user_session.set("runnable", model)
+     cl.user_session.set("retriever", sentence_retriever)
+     cl.user_session.set("prompt_template", rag_prompt)
+
+
+ @cl.on_message
+ async def on_message(message: cl.Message):
+     # Get the stored model, retriever, and prompt
+     model = cast(ChatOpenAI, cl.user_session.get("runnable"))
+     retriever = cl.user_session.get("retriever")
+     prompt_template = cl.user_session.get("prompt_template")
+
+     # Log the incoming message
+     print(f"Received message: {message.content}")
+
+     # Retrieve context relevant to the user's question
+     relevant_docs = retriever.get_relevant_documents(message.content)
+     print(f"Retrieved {len(relevant_docs)} documents.")
+
+     if not relevant_docs:
+         print("No relevant documents found.")
+         await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
+         return
+
+     context = "\n\n".join([doc.page_content for doc in relevant_docs])
+
+     # Log the assembled context
+     print(f"Context: {context}")
+
+     # Fill the RAG prompt with the retrieved context and the user's question
+     final_prompt = prompt_template.format(context=context, question=message.content)
+     print(f"Final prompt: {final_prompt}")
+
+     # Initialize an empty message and stream tokens into it
+     msg = cl.Message(content="")
+
+     # ChatOpenAI accepts a plain string as input, so the formatted
+     # prompt can be streamed directly
+     async for chunk in model.astream(
+         final_prompt,
+         config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
+     ):
+         await msg.stream_token(chunk.content)
+
+     await msg.send()
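
A minimal sketch of how to try the app locally, assuming the standard Chainlit CLI and a .env file holding the OpenAI key (the package list below is inferred from the imports, not part of the commit):

  # .env
  OPENAI_API_KEY=sk-...

  pip install chainlit langchain langchain-openai langchain-community qdrant-client pymupdf sentence-transformers python-dotenv
  chainlit run app_v1.py -w

chainlit run serves the UI on http://localhost:8000 by default; -w reloads the app whenever the file changes.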