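"""Minimal RAG (retrieval-augmented generation) pipeline over a local PDF.

Loads a PDF with PyMuPDF, splits it into overlapping chunks, embeds the
chunks into an in-memory FAISS index, and answers a question with an OpenAI
chat model grounded in the retrieved context.

Requires the OPENAI_API_KEY environment variable to be set.
"""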
# import libraries
import os
from operator import itemgetter

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


LLM_MODEL_NAME = "gpt-4-turbo"
# LLM_MODEL_NAME = "gpt-3.5-turbo"  # cheaper alternative


# load a PDF into a list of LangChain Document objects (one per page)
def load_pdf_to_text(pdf_path):
    loader = PyMuPDFLoader(pdf_path)
    # load() returns one Document per page, with the text in .page_content
    doc = loader.load()
    return doc

# split the page Documents into overlapping chunks for embedding
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=700,     # maximum characters per chunk
        chunk_overlap=100,  # characters shared between adjacent chunks
    )
    chunks = splitter.split_documents(text)
    return chunks

# embed the chunks and load them into an in-memory FAISS index
def load_text_to_index(doc_splits):
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vector_store = FAISS.from_documents(doc_splits, embeddings)
    retriever = vector_store.as_retriever()
    return retriever
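
# Note: the index above lives only in memory, so it is rebuilt on every run.
# FAISS supports persisting it between runs; a sketch (the folder name
# "faiss_index" is arbitrary, and allow_dangerous_deserialization is required
# in recent langchain_community versions when reloading a local pickle):
#
#     vector_store.save_local("faiss_index")
#     vector_store = FAISS.load_local(
#         "faiss_index", embeddings, allow_dangerous_deserialization=True
#     )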

# query FAISS index
def query_index(retriever, query):
    retrieved_docs = retriever.invoke(query)
    return retrieved_docs

# create the QA prompt; the model is instructed to answer only from the retrieved context
def create_answer_prompt():
    template = """Answer the question based only on the following context. If you cannot answer the question with the context, please respond with 'I don't know':

    Context:
    {context}

    Question:
    {question}
    """
    prompt = ChatPromptTemplate.from_template(template)
    return prompt

# retrieve context for the question and generate a grounded answer
def generate_answer(retriever, answer_prompt, query):
    QnA_LLM = ChatOpenAI(model=LLM_MODEL_NAME, temperature=0.0)

    retrieval_qna_chain = (
        # fan out: route the question to the retriever to fetch context,
        # and pass the question itself through unchanged
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        # return both the model's response and the retrieved context
        | {"response": answer_prompt | QnA_LLM, "context": itemgetter("context")}
    )
    result = retrieval_qna_chain.invoke({"question": query})
    return result
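
# The chain above passes the raw list of Documents into the prompt, which
# stringifies them (metadata included). A common refinement, shown here only
# as an optional sketch, is to join just the page text before it reaches the
# prompt:
#
#     def format_docs(docs):
#         return "\n\n".join(d.page_content for d in docs)
#
#     # then: {"context": itemgetter("question") | retriever | format_docs, ...}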

# build the retriever from the PDF on disk
def initialize_index():
    # the PDF is expected at ./data/Samsara_AG.pdf relative to the working directory
    pdf_path = os.path.join(os.getcwd(), "data", "Samsara_AG.pdf")
    print("path:", pdf_path)
    doc = load_pdf_to_text(pdf_path)
    print("pages loaded:", len(doc))
    doc_splits = split_text(doc)
    print("chunks:", len(doc_splits))
    retriever = load_text_to_index(doc_splits)
    return retriever

def main():
    retriever = initialize_index()
    query = "how to build the best product?"
    retrieved_docs = query_index(retriever, query)
    print("retrieved docs:", len(retrieved_docs))
    answer_prompt = create_answer_prompt()
    result = generate_answer(retriever, answer_prompt, query)
    print("result:\n", result["response"].content)

if __name__ == "__main__":
    main()
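
# Usage (a sketch; the script filename is a placeholder, and it assumes the
# PDF exists at data/Samsara_AG.pdf under the working directory):
#
#     export OPENAI_API_KEY=sk-...
#     python rag_pdf.py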