File size: 5,032 Bytes
03ab966
3ede494
eceefb4
 
 
 
 
 
 
 
 
 
 
 
 
 
40e55f0
 
 
7ae1795
 
aa92ef4
516ec1c
6ea7ef9
 
516ec1c
 
 
 
 
 
 
 
3ede494
 
 
 
340e058
3ede494
 
340e058
3ede494
 
56a0e0e
3ede494
340e058
3ede494
 
56a0e0e
3ede494
340e058
90a2028
 
 
4115e3a
 
4a679bd
340e058
 
3ede494
 
 
56a0e0e
340e058
 
 
4115e3a
340e058
6ea7ef9
3ede494
 
340e058
4115e3a
340e058
6ea7ef9
3ede494
 
 
40e55f0
4115e3a
6ea7ef9
3ede494
 
40e55f0
4115e3a
3ede494
 
6ea7ef9
3ede494
 
309e834
 
 
 
 
 
 
 
08cc2d6
03ab966
 
 
 
08cc2d6
 
3ede494
 
f43960a
3ede494
f43960a
3ede494
 
4e80daf
08cc2d6
f43960a
4e80daf
40e55f0
 
 
f43960a
 
3ede494
 
 
15fb20f
f43960a
15fb20f
90a2028
f43960a
03ab966
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import openai, os

from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch

from pymongo import MongoClient

# Values accepted by rag_chain()'s rag_option argument to select the vector store.
RAG_CHROMA  = "Chroma"
RAG_MONGODB = "MongoDB"

# Source documents used by document_loading().
PDF_URL     = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL     = "https://openai.com/research/gpt-4"
YOUTUBE_URL = "https://www.youtube.com/watch?v=qdd2GZ0DVgc"

# Local directories. YOUTUBE_DIR is the directory handed to YoutubeAudioLoader
# (presumably where downloaded audio is saved — confirm). CHROMA_DIR is the
# persist_directory used both when storing to and when reopening the Chroma store.
YOUTUBE_DIR = "/data/youtube"
CHROMA_DIR  = "/data/chroma"

# MongoDB Atlas settings. The cluster URI is read from the environment at import
# time, so importing this module raises KeyError if MONGODB_ATLAS_CLUSTER_URI is unset.
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
MONGODB_DB_NAME           = "langchain_db"
MONGODB_COLLECTION_NAME   = "gpt-4"
# index_name passed to MongoDBAtlasVectorSearch for both storage and retrieval.
# NOTE(review): presumably an Atlas vector-search index with this name must exist on the collection.
MONGODB_INDEX_NAME        = "default"

# Prompt templates, read from the LLM_TEMPLATE / RAG_TEMPLATE environment variables
# at import time (KeyError if either is unset). The plain-LLM template is declared
# with {question}; the RAG template with {context} and {question}.
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])

# Module-level Mongo client and target collection, created at import time; the
# collection is what document_storage_mongodb() writes the embedded chunks into.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]

def document_loading(*, include_youtube = False):
    """Load the source documents that feed the RAG vector stores.

    Always loads the PDF at PDF_URL and the web page at WEB_URL. The YouTube
    source (YOUTUBE_URL, fetched with YoutubeAudioLoader into YOUTUBE_DIR and
    parsed with OpenAIWhisperParser) is opt-in. The default keeps the previous
    behavior, where that loader was disabled.

    Args:
        include_youtube: also load the YouTube source. Keyword-only and
            False by default, so existing ``document_loading()`` calls are
            unaffected.

    Returns:
        list of documents collected from every enabled source, in the
        order PDF, Web, YouTube.
    """
    docs = []

    # PDF
    docs.extend(PyPDFLoader(PDF_URL).load())
    print("### Load PDF")

    # Web
    docs.extend(WebBaseLoader(WEB_URL).load())
    print("### Load Web")

    # YouTube: previously commented out while still logging "### Load YouTube",
    # which reported a step that never ran. It now runs only on request and
    # says so honestly when skipped.
    if include_youtube:
        loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL], YOUTUBE_DIR),
                               OpenAIWhisperParser())
        docs.extend(loader.load())
        print("### Load YouTube")
    else:
        print("### Skip YouTube")

    return docs

def document_splitting(config, docs):
    """Break the loaded documents into overlapping chunks ready for embedding.

    ``config`` must supply "chunk_overlap" and "chunk_size"; both are handed
    unchanged to RecursiveCharacterTextSplitter. Returns the list of chunks
    produced by ``split_documents``.
    """
    overlap = config["chunk_overlap"]
    size = config["chunk_size"]
    splitter = RecursiveCharacterTextSplitter(chunk_overlap = overlap,
                                              chunk_size = size)

    print("### Split")
    chunks = splitter.split_documents(docs)
    return chunks
    
def document_storage_chroma(chunks):
    """Embed ``chunks`` with OpenAI embeddings and index them in Chroma.

    CHROMA_DIR is used as the store's persist_directory. The created store
    object is not returned.
    """
    print("### Store Chroma")
    embedder = OpenAIEmbeddings(disallowed_special = ())
    Chroma.from_documents(documents = chunks,
                          embedding = embedder,
                          persist_directory = CHROMA_DIR)

def document_storage_mongodb(chunks):
    """Embed ``chunks`` with OpenAI embeddings and write them to MongoDB Atlas.

    Uses the module-level ``collection`` and MONGODB_INDEX_NAME. The created
    vector-store object is not returned.
    """
    print("### Store MongoDB")
    embedder = OpenAIEmbeddings(disallowed_special = ())
    MongoDBAtlasVectorSearch.from_documents(documents = chunks,
                                            embedding = embedder,
                                            collection = collection,
                                            index_name = MONGODB_INDEX_NAME)

def document_retrieval_chroma():
    """Reopen the Chroma store at CHROMA_DIR and return it for querying.

    The store is paired with the same OpenAIEmbeddings configuration used at
    storage time.
    """
    print("### Retrieve Chroma")
    embedder = OpenAIEmbeddings(disallowed_special = ())
    store = Chroma(embedding_function = embedder,
                   persist_directory = CHROMA_DIR)
    return store

def document_retrieval_mongodb():
    """Connect to the MongoDB Atlas vector store and return it for querying.

    The namespace is "<db>.<collection>", built from MONGODB_DB_NAME and
    MONGODB_COLLECTION_NAME. The store uses MONGODB_INDEX_NAME and the same
    OpenAIEmbeddings configuration as at storage time.
    """
    print("### Retrieve MongoDB")
    namespace = f"{MONGODB_DB_NAME}.{MONGODB_COLLECTION_NAME}"
    embedder = OpenAIEmbeddings(disallowed_special = ())
    return MongoDBAtlasVectorSearch.from_connection_string(
        MONGODB_ATLAS_CLUSTER_URI,
        namespace,
        embedder,
        index_name = MONGODB_INDEX_NAME,
    )

def rag_batch(config):
    """Run the offline ingestion pipeline.

    Loads the sources, splits them using ``config`` (chunk settings), then
    writes the same chunks to Chroma first and MongoDB second.
    """
    chunks = document_splitting(config, document_loading())

    for store_chunks in (document_storage_chroma, document_storage_mongodb):
        store_chunks(chunks)

def get_llm(config, openai_api_key):
    """Build the ChatOpenAI model from ``config`` and the caller's API key.

    Reads ``config["model_name"]`` and ``config["temperature"]``.
    """
    settings = {
        "model_name": config["model_name"],
        "openai_api_key": openai_api_key,
        "temperature": config["temperature"],
    }
    return ChatOpenAI(**settings)

def llm_chain(config, openai_api_key, prompt):
    """Answer ``prompt`` with the LLM alone, without retrieval.

    Note that ``prompt`` is the user's question: it fills the {question}
    slot of LLM_CHAIN_PROMPT, which is the actual prompt template.

    Returns:
        (completion, chain): the result of ``LLMChain.generate`` for the
        single input, and the LLMChain that produced it.
    """
    chain = LLMChain(llm = get_llm(config, openai_api_key),
                     prompt = LLM_CHAIN_PROMPT,
                     verbose = False)

    inputs = [{"question": prompt}]
    return chain.generate(inputs), chain

def rag_chain(config, openai_api_key, rag_option, prompt):
    """Answer ``prompt`` with retrieval-augmented generation.

    Args:
        config: dict providing "model_name" and "temperature" (used by
            get_llm) and "k", the number of documents the retriever returns.
        openai_api_key: key passed through to ChatOpenAI.
        rag_option: RAG_CHROMA or RAG_MONGODB; selects the vector store
            that is queried.
        prompt: the user's question, passed to the chain as "query".

    Returns:
        (completion, chain): the RetrievalQA output (source documents are
        included because return_source_documents=True) and the chain itself.

    Raises:
        ValueError: if ``rag_option`` is neither RAG_CHROMA nor RAG_MONGODB.
            Previously this case left ``db`` unbound and failed later with
            an UnboundLocalError.
    """
    llm = get_llm(config, openai_api_key)

    if rag_option == RAG_CHROMA:
        db = document_retrieval_chroma()
    elif rag_option == RAG_MONGODB:
        db = document_retrieval_mongodb()
    else:
        raise ValueError(
            f"Unknown rag_option {rag_option!r}; "
            f"expected {RAG_CHROMA!r} or {RAG_MONGODB!r}"
        )

    # The local is named `chain` so it does not shadow this function's name.
    chain = RetrievalQA.from_chain_type(llm,
                                        chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
                                        retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
                                        return_source_documents = True,
                                        verbose = True)

    completion = chain({"query": prompt})
    print(completion)

    return completion, chain