Cheselle committed
Commit 5ebf50b · verified · 1 Parent(s): 5776c09

Update app.py

Files changed (1): app.py +79 -97
app.py CHANGED
@@ -1,131 +1,113 @@
- import re
-
- from langchain_openai import OpenAIEmbeddings
  from langchain_openai import ChatOpenAI
- from langchain_openai.embeddings import OpenAIEmbeddings
-
  from langchain.prompts import ChatPromptTemplate
- from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.schema import StrOutputParser
-
  from langchain_community.document_loaders import PyMuPDFLoader
  from langchain_community.vectorstores import Qdrant
-
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel
- from langchain_core.documents import Document
-
- from operator import itemgetter
- import os
- from dotenv import load_dotenv
  import chainlit as cl

  load_dotenv()

- ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
- ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()

- def metadata_generator(document, name):
-     fixed_text_splitter = RecursiveCharacterTextSplitter(
          chunk_size=500,
          chunk_overlap=100,
          separators=["\n\n", "\n", ".", "!", "?"]
      )
-     collection = fixed_text_splitter.split_documents(document)
-     for doc in collection:
-         doc.metadata["source"] = name
-     return collection
-
- recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
- recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
- combined_documents = recursive_framework_document + recursive_blueprint_document
-
-
- #from transformers import AutoTokenizer, AutoModel
- #import torch
- #embeddings = AutoModel.from_pretrained("Cheselle/finetuned-arctic-sentence")
- #tokenizer = AutoTokenizer.from_pretrained("Cheselle/finetuned-arctic-sentence")

- # Assuming ai_framework_document and ai_blueprint_document are lists of langchain_core.documents.Document
- ai_framework_text = "".join([doc.page_content for doc in ai_framework_document])

- # Similarly for ai_blueprint_document
- ai_blueprint_text = "".join([doc.page_content for doc in ai_blueprint_document])
-
- # Now you can use these text variables

- from sentence_transformers import SentenceTransformer
- embedding_model = SentenceTransformer("Cheselle/finetuned-arctic-sentence")
- embeddings = embedding_model.encode(ai_framework_text + ai_blueprint_text)
- #embeddings = embedding_model.encode(ai_framework_text + ai_blueprint_text)
- #embeddings = embedding_model.encode(ai_framework_document + ai_blueprint_document)

- vectorstore = Qdrant.from_documents(
-     documents=combined_documents,
-     embedding=lambda docs: embedding_model.encode([doc.page_content for doc in docs]),
-     #embedding=embedding_model,
-     #embedding=embeddings,
-     location=":memory:",
-     collection_name="ai_policy"
- )

- retriever = vectorstore.as_retriever()

- ## Generation LLM
- llm = ChatOpenAI(model="gpt-4o-mini")

- RAG_PROMPT = """\
- You are an AI Policy Expert.
- Given a provided context and question, you must answer the question based only on context.
- Think through your answer carefully and step by step.
-
- Context: {context}
- Question: {question}
- """

- rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

- retrieval_augmented_qa_chain = (
-     # INVOKE CHAIN WITH: {"question" : "<<SOME USER QUESTION>>"}
-     # "question" : populated by getting the value of the "question" key
-     # "context" : populated by getting the value of the "question" key and chaining it into the base_retriever
-     {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
-     # "context" : is assigned to a RunnablePassthrough object (will not be called or considered in the next step)
-     # by getting the value of the "context" key from the previous step
-     | RunnablePassthrough.assign(context=itemgetter("context"))
-     # "response" : the "context" and "question" values are used to format our prompt object and then piped
-     # into the LLM and stored in a key called "response"
-     # "context" : populated by getting the value of the "context" key from the previous step
-     | {"response": rag_prompt | llm, "context": itemgetter("context")}
- )

- #alt_rag_chain.invoke({"question" : "What is the AI framework all about?"})

- @cl.on_message
- async def handle_message(message):
-     try:
-         # Process the incoming question using the RAG chain
-         result = retrieval_augmented_qa_chain.invoke({"question": message.content})

-         # Create a new message for the response
-         response_message = cl.Message(content=result["response"].content)

-         # Send the response back to the user
-         await response_message.send()
-
-     except Exception as e:
-         # Handle any exception and log it or send a response back to the user
-         error_message = cl.Message(content=f"An error occurred: {str(e)}")
-         await error_message.send()
-         print(f"Error occurred: {e}")
-
- # Run the ChainLit server
- if __name__ == "__main__":
-     try:
-         cl.run()
-     except Exception as e:
-         print(f"Server error occurred: {e}")
  from langchain_openai import ChatOpenAI
  from langchain.prompts import ChatPromptTemplate
  from langchain.schema import StrOutputParser
+ from langchain.schema.runnable import Runnable
+ from langchain.schema.runnable.config import RunnableConfig
+ from typing import cast
+ from dotenv import load_dotenv
+ import os
  from langchain_community.document_loaders import PyMuPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_openai.embeddings import OpenAIEmbeddings
  from langchain_community.vectorstores import Qdrant
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel
  import chainlit as cl
+ from pathlib import Path
+ from sentence_transformers import SentenceTransformer  # Ensure this import is correct

  load_dotenv()

+ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

+ @cl.on_chat_start
+ async def on_chat_start():
+     model = ChatOpenAI(streaming=True)

+     # Load documents
+     ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
+     ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
+
+     RAG_PROMPT = """\
+     Given a provided context and question, you must answer the question based only on context.
+
+     Context: {context}
+     Question: {question}
+     """
+
+     rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

+     sentence_text_splitter = RecursiveCharacterTextSplitter(
          chunk_size=500,
          chunk_overlap=100,
          separators=["\n\n", "\n", ".", "!", "?"]
      )

+     def metadata_generator(document, name, splitter):
+         collection = splitter.split_documents(document)
+         for doc in collection:
+             doc.metadata["source"] = name
+         return collection

+     sentence_framework = metadata_generator(ai_framework_document, "AI Framework", sentence_text_splitter)
+     sentence_blueprint = metadata_generator(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)

+     sentence_combined_documents = sentence_framework + sentence_blueprint

+     # Initialize the SentenceTransformer model properly
+     embedding_model = SentenceTransformer('Cheselle/finetuned-arctic-sentence')

+     # Create the Qdrant vector store using the initialized embedding model
+     sentence_vectorstore = Qdrant.from_documents(
+         documents=sentence_combined_documents,
+         embedding=embedding_model,  # Ensure this is an instance
+         location=":memory:",
+         collection_name="AI Policy"
+     )

+     sentence_retriever = sentence_vectorstore.as_retriever()
+
+     # Set the retriever and prompt into session for reuse
+     cl.user_session.set("runnable", model)
+     cl.user_session.set("retriever", sentence_retriever)
+     cl.user_session.set("prompt_template", rag_prompt)

+ @cl.on_message
+ async def on_message(message: cl.Message):
+     # Get the stored model, retriever, and prompt
+     model = cast(ChatOpenAI, cl.user_session.get("runnable"))
+     retriever = cl.user_session.get("retriever")
+     prompt_template = cl.user_session.get("prompt_template")

+     # Log the message content
+     print(f"Received message: {message.content}")

+     # Retrieve relevant context from documents based on the user's message
+     relevant_docs = retriever.get_relevant_documents(message.content)
+     print(f"Retrieved {len(relevant_docs)} documents.")

+     if not relevant_docs:
+         print("No relevant documents found.")
+         await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
+         return

+     context = "\n\n".join([doc.page_content for doc in relevant_docs])

+     # Log the context to check
+     print(f"Context: {context}")

+     # Construct the final RAG prompt
+     final_prompt = prompt_template.format(context=context, question=message.content)
+     print(f"Final prompt: {final_prompt}")

+     # Initialize a streaming message
+     msg = cl.Message(content="")

+     # Stream the response from the model
+     async for chunk in model.astream(
+         final_prompt,
+         config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
+     ):
+         await msg.stream_token(chunk.content)

+     await msg.send()
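
Note on the new version: embedding=embedding_model passes a raw SentenceTransformer, which is not a LangChain Embeddings instance and does not expose embed_documents/embed_query, so Qdrant.from_documents would still be expected to fail here. One possible fix, sketched under the assumption that the HuggingFaceEmbeddings wrapper from langchain_community (which loads sentence-transformers models by name) is an acceptable substitute; variable names follow the committed code:

    # Hypothetical replacement for the vector-store block above, not part of
    # this commit: HuggingFaceEmbeddings wraps sentence-transformers models
    # behind the Embeddings interface that Qdrant.from_documents expects.
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import Qdrant

    embedding_model = HuggingFaceEmbeddings(model_name="Cheselle/finetuned-arctic-sentence")

    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,  # chunks built in on_chat_start
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy",
    )

Two smaller caveats: os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") raises a TypeError when the variable is unset and is redundant after load_dotenv(); and prompt_template.format(...) yields a single string, where format_messages(...) (or invoking the prompt in a chain) is the more idiomatic input for a chat model, although astream does accept a plain string.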