Mr-Cool committed on
Commit
ddca922
·
verified ·
1 Parent(s): 9f175f1

changed to gpt-4o-mini model

Browse files
Files changed (1) hide show
  1. functions.py +96 -96
functions.py CHANGED
@@ -1,97 +1,97 @@
1
- from langchain_community.document_loaders import PyMuPDFLoader
2
- from langchain_text_splitters import RecursiveCharacterTextSplitter
3
- from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
4
- from operator import itemgetter
5
- from langchain_core.runnables import RunnablePassthrough
6
- from langchain_qdrant import QdrantVectorStore
7
- from qdrant_client import QdrantClient
8
- from qdrant_client.http.models import Distance, VectorParams
9
- from langchain.prompts import ChatPromptTemplate
10
- import tiktoken
11
- import os
12
-
13
- ### SETUP FUNCTIONS ###
14
- def tiktoken_len(text):
15
- tokens = tiktoken.encoding_for_model("gpt-4o").encode(
16
- text,
17
- )
18
- return len(tokens)
19
-
20
- def setup_vector_db():
21
-
22
- # Get the directory of the current file
23
- current_file_directory = os.path.dirname(os.path.abspath(__file__))
24
- # Change the working directory to the current file's directory
25
- os.chdir(current_file_directory)
26
-
27
- # Load the NIST AI document
28
- PDF_LINK = "data/nist_ai.pdf"
29
- loader = PyMuPDFLoader(file_path=PDF_LINK)
30
- nist_doc = loader.load()
31
-
32
- text_splitter = RecursiveCharacterTextSplitter(
33
- chunk_size = 500,
34
- chunk_overlap = 100,
35
- length_function = tiktoken_len,
36
- )
37
-
38
- nist_chunks = text_splitter.split_documents(nist_doc)
39
-
40
- embeddings_small = AzureOpenAIEmbeddings(azure_deployment="text-embedding-3-small")
41
-
42
- qdrant_client = QdrantClient(":memory:") # set Qdrant DB and its location (in-memory)
43
-
44
- qdrant_client.create_collection(
45
- collection_name="NIST_AI",
46
- vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
47
- )
48
-
49
- qdrant_vector_store = QdrantVectorStore(
50
- client=qdrant_client,
51
- collection_name="NIST_AI",
52
- embedding=embeddings_small,
53
- ) # create a QdrantVectorStore object with the above specified client, collection name, and embedding model.
54
-
55
- qdrant_vector_store.add_documents(nist_chunks) # add the documents to the QdrantVectorStore
56
-
57
- retriever = qdrant_vector_store.as_retriever()
58
-
59
- return retriever
60
-
61
- ### VARIABLES ###
62
-
63
- # define a global variable to store the retriever object
64
- retriever = setup_vector_db()
65
- qa_gpt4_llm = AzureChatOpenAI(azure_deployment="gpt-4", temperature=0) # GPT-4o model
66
-
67
- # define a template for the RAG model
68
- rag_template = """
69
- You are a helpful assistant that helps users find information and answer their question.
70
- You MUST use ONLY the available context to answer the question.
71
- If necessary information to answer the question cannot be found in the provided context, you MUST "I don't know."
72
-
73
- Question:
74
- {question}
75
-
76
- Context:
77
- {context}
78
- """
79
- # create rag prompt object from the template
80
- prompt = ChatPromptTemplate.from_template(rag_template)
81
-
82
- # update the chain with LLM, prompt, and question variable.
83
- retrieval_augmented_qa_chain = (
84
- {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
85
- | RunnablePassthrough.assign(context=itemgetter("context"))
86
- | {"response": prompt | qa_gpt4_llm, "context": itemgetter("context"), "question": itemgetter("question")}
87
- )
88
-
89
- ### FUNCTIONS ###
90
-
91
-
92
- def get_response(query, history):
93
- """A helper function to get the response from the RAG model and return it to the UI."""
94
-
95
- response = retrieval_augmented_qa_chain.invoke({"question" : query})
96
-
97
  return response["response"].content
 
1
+ from langchain_community.document_loaders import PyMuPDFLoader
2
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
3
+ from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
4
+ from operator import itemgetter
5
+ from langchain_core.runnables import RunnablePassthrough
6
+ from langchain_qdrant import QdrantVectorStore
7
+ from qdrant_client import QdrantClient
8
+ from qdrant_client.http.models import Distance, VectorParams
9
+ from langchain.prompts import ChatPromptTemplate
10
+ import tiktoken
11
+ import os
12
+
13
+ ### SETUP FUNCTIONS ###
14
# Hoist the tokenizer lookup to import time: tiktoken.encoding_for_model()
# builds/loads the encoding, which is too expensive to repeat on every call
# (tiktoken_len is invoked once per candidate chunk during text splitting).
_GPT4O_ENCODING = tiktoken.encoding_for_model("gpt-4o")

def tiktoken_len(text):
    """Return the number of gpt-4o tokens in *text*.

    Used as the ``length_function`` for RecursiveCharacterTextSplitter so
    chunk sizes are measured in model tokens rather than characters.
    """
    return len(_GPT4O_ENCODING.encode(text))
19
+
20
def setup_vector_db():
    """Load the NIST AI PDF, chunk it, embed it into an in-memory Qdrant
    collection, and return a retriever over that collection.

    Returns:
        A LangChain retriever backed by the populated QdrantVectorStore.
        The store is in-memory, so it is rebuilt on every process start.
    """
    # Resolve the PDF path relative to this file instead of calling
    # os.chdir(): mutating the process-wide working directory is a hidden
    # global side effect that can break unrelated code in the same process.
    current_file_directory = os.path.dirname(os.path.abspath(__file__))
    pdf_path = os.path.join(current_file_directory, "data", "nist_ai.pdf")

    # Load the NIST AI document.
    loader = PyMuPDFLoader(file_path=pdf_path)
    nist_doc = loader.load()

    # Token-based chunking (tiktoken_len) keeps each chunk within the
    # embedding model's limits; the 100-token overlap preserves context
    # across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        length_function=tiktoken_len,
    )
    nist_chunks = text_splitter.split_documents(nist_doc)

    embeddings_small = AzureOpenAIEmbeddings(azure_deployment="text-embedding-3-small")

    # In-memory Qdrant instance; nothing is persisted to disk.
    qdrant_client = QdrantClient(":memory:")
    qdrant_client.create_collection(
        collection_name="NIST_AI",
        # size=1536 matches the output dimension of text-embedding-3-small.
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )

    # Wrap the client/collection with the embedding model and index the chunks.
    qdrant_vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name="NIST_AI",
        embedding=embeddings_small,
    )
    qdrant_vector_store.add_documents(nist_chunks)

    return qdrant_vector_store.as_retriever()
60
+
61
### VARIABLES ###

# Build the retriever once at import time so every request reuses the same
# in-memory vector store.
retriever = setup_vector_db()
qa_gpt4_llm = AzureChatOpenAI(azure_deployment="gpt-4o-mini", temperature=0)  # GPT-4o-mini model

# Prompt for the RAG chain. Fix: the original fallback instruction read
# 'you MUST "I don't know."' -- the missing verb made the sentence
# ungrammatical; "MUST say" states the intended fallback correctly.
rag_template = """
You are a helpful assistant that helps users find information and answer their question.
You MUST use ONLY the available context to answer the question.
If necessary information to answer the question cannot be found in the provided context, you MUST say "I don't know."

Question:
{question}

Context:
{context}
"""
# create rag prompt object from the template
prompt = ChatPromptTemplate.from_template(rag_template)

# RAG chain: retrieve context for the question, then ask the LLM with both.
# The final mapping returns the raw context and question alongside the
# response so callers can inspect what the model actually saw.
retrieval_augmented_qa_chain = (
    {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompt | qa_gpt4_llm, "context": itemgetter("context"), "question": itemgetter("question")}
)
88
+
89
+ ### FUNCTIONS ###
90
+
91
+
92
def get_response(query, history):
    """Answer *query* via the RAG chain and return the answer text for the UI.

    ``history`` is accepted for chat-UI callback compatibility but is unused.
    """
    result = retrieval_augmented_qa_chain.invoke({"question": query})
    return result["response"].content