File size: 4,987 Bytes
03ab966
3ede494
478d345
eceefb4
 
 
 
 
 
 
 
 
 
 
 
 
 
40e55f0
 
 
dcc2644
 
 
 
516ec1c
6514d80
 
516ec1c
 
 
 
 
 
 
 
3ede494
 
 
 
c8a9d42
3ede494
 
340e058
3ede494
 
 
340e058
3ede494
 
 
340e058
dcc2644
4d86a48
 
4115e3a
4a679bd
340e058
c8a9d42
3ede494
 
 
340e058
 
c8a9d42
340e058
c8efcca
3ede494
 
c8a9d42
340e058
c8efcca
3ede494
 
 
c8a9d42
 
9c86fb0
c8a9d42
9c86fb0
c8a9d42
 
9c86fb0
c8a9d42
ab0af2e
 
3ede494
401a2a7
ab0af2e
 
 
 
401a2a7
08cc2d6
03ab966
 
 
 
c8a9d42
08cc2d6
704c818
f43960a
5edb564
b03208e
f43960a
4e625ab
3ede494
c8a9d42
08cc2d6
9ec3206
4e80daf
c8a9d42
40e55f0
c8a9d42
96012de
f43960a
056f437
 
704c818
 
f43960a
478d345
cde25a7
 
4e625ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import openai, os

from langchain.callbacks import get_openai_callback
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch

from pymongo import MongoClient

# Identifiers for the supported vector stores; run_rag_chain() compares its
# rag_option argument against these two values.
RAG_CHROMA  = "Chroma"
RAG_MONGODB = "MongoDB"

# Source material ingested by load_documents(): a PDF, a web page and the
# audio of two YouTube videos.
PDF_URL       = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL       = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"

# Local storage: YOUTUBE_DIR is handed to YoutubeAudioLoader (where the
# downloaded audio goes); CHROMA_DIR is the Chroma persist_directory used both
# when writing and when reopening the store.
YOUTUBE_DIR = "/data/yt"
CHROMA_DIR  = "/data/db"

# MongoDB Atlas vector-search settings. The cluster URI is read from the
# environment at import time, so importing this module raises KeyError when
# MONGODB_ATLAS_CLUSTER_URI is unset.
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
MONGODB_DB_NAME           = "langchain_db"
MONGODB_COLLECTION_NAME   = "gpt-4"
MONGODB_INDEX_NAME        = "default"

# Prompt templates are also taken from the environment (KeyError at import if
# unset). The LLM_TEMPLATE text is expected to reference {question}; the
# RAG_TEMPLATE text is expected to reference {context} and {question}.
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])

# Shared MongoDB client and target collection, created at import time; the
# collection is what embed_store_documents_mongodb() writes into.
# NOTE(review): constructing MongoClient here is an import-time side effect;
# consider creating it lazily if this module is imported by tooling/tests.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]

def load_documents():
    """Load all source documents for the RAG index.

    Pulls, in this order: the PDF at PDF_URL, the web page at WEB_URL, and the
    audio of YOUTUBE_URL_1 / YOUTUBE_URL_2 (saved under YOUTUBE_DIR and parsed
    with OpenAIWhisperParser).

    Returns:
        A flat list of LangChain documents in PDF -> web -> YouTube order.
    """
    # Factories rather than ready-made loaders, so each loader is constructed
    # only right before it is drained, one source at a time.
    loader_factories = (
        lambda: PyPDFLoader(PDF_URL),
        lambda: WebBaseLoader(WEB_URL),
        lambda: GenericLoader(
            YoutubeAudioLoader([YOUTUBE_URL_1, YOUTUBE_URL_2], YOUTUBE_DIR),
            OpenAIWhisperParser(),
        ),
    )

    return [
        document
        for make_loader in loader_factories
        for document in make_loader().load()
    ]

def split_documents(config, docs):
    """Split ``docs`` into overlapping text chunks.

    Args:
        config: Mapping that supplies ``chunk_overlap`` and ``chunk_size`` for
            RecursiveCharacterTextSplitter.
        docs: Documents to split.

    Returns:
        The list of chunk documents produced by the splitter.
    """
    overlap = config["chunk_overlap"]
    size = config["chunk_size"]
    splitter = RecursiveCharacterTextSplitter(chunk_overlap = overlap, chunk_size = size)
    chunks = splitter.split_documents(docs)
    return chunks
    
def embed_store_documents_chroma(chunks):
    """Embed ``chunks`` with OpenAI embeddings and write them to the Chroma
    store persisted at CHROMA_DIR."""
    # disallowed_special=() lets text containing special-token strings be
    # tokenized as ordinary text instead of being rejected.
    embedder = OpenAIEmbeddings(disallowed_special = ())
    Chroma.from_documents(documents = chunks,
                          embedding = embedder,
                          persist_directory = CHROMA_DIR)

def embed_store_documents_mongodb(chunks):
    """Embed ``chunks`` with OpenAI embeddings and insert them into the
    module-level MongoDB Atlas ``collection`` under MONGODB_INDEX_NAME."""
    # Same tokenizer setting as the Chroma path so both stores see identical input.
    embedder = OpenAIEmbeddings(disallowed_special = ())
    MongoDBAtlasVectorSearch.from_documents(documents = chunks,
                                            embedding = embedder,
                                            collection = collection,
                                            index_name = MONGODB_INDEX_NAME)

def run_rag_batch(config):
    """Ingestion pipeline: load all sources, chunk them per ``config`` and
    index the chunks into Chroma first, then MongoDB Atlas."""
    chunks = split_documents(config, load_documents())

    for store_chunks in (embed_store_documents_chroma, embed_store_documents_mongodb):
        store_chunks(chunks)

def retrieve_documents_chroma():
    """Reopen the Chroma store persisted at CHROMA_DIR for similarity search."""
    embedder = OpenAIEmbeddings(disallowed_special = ())
    store = Chroma(embedding_function = embedder,
                   persist_directory = CHROMA_DIR)
    return store

def retrieve_documents_mongodb():
    """Connect to the MongoDB Atlas vector store for similarity search.

    The namespace passed to LangChain is ``"<db>.<collection>"`` built from
    MONGODB_DB_NAME and MONGODB_COLLECTION_NAME.
    """
    namespace = f"{MONGODB_DB_NAME}.{MONGODB_COLLECTION_NAME}"
    embedder = OpenAIEmbeddings(disallowed_special = ())
    return MongoDBAtlasVectorSearch.from_connection_string(
        MONGODB_ATLAS_CLUSTER_URI,
        namespace,
        embedder,
        index_name = MONGODB_INDEX_NAME,
    )
    
def get_llm(config, openai_api_key):
    """Build the ChatOpenAI model named by ``config["model_name"]`` with
    ``config["temperature"]``, authenticated with ``openai_api_key``."""
    settings = {
        "model_name": config["model_name"],
        "openai_api_key": openai_api_key,
        "temperature": config["temperature"],
    }
    return ChatOpenAI(**settings)

def run_llm_chain(config, openai_api_key, prompt):
    """Answer ``prompt`` with the bare LLM (no retrieval).

    Returns:
        ``(completion, llm_chain, cb)``: the result of ``LLMChain.generate``,
        the chain itself, and the OpenAI callback handler that was active
        during generation.
    """
    model = get_llm(config, openai_api_key)
    chain = LLMChain(llm = model, prompt = LLM_CHAIN_PROMPT)
    batch = [{"question": prompt}]

    with get_openai_callback() as usage_cb:
        result = chain.generate(batch)

    return result, chain, usage_cb

def run_rag_chain(config, openai_api_key, rag_option, prompt):
    """Answer ``prompt`` with retrieval-augmented generation.

    Args:
        config: Mapping providing ``model_name`` and ``temperature`` (for the
            chat model) and ``k`` (number of chunks to retrieve).
        openai_api_key: API key passed to the chat model.
        rag_option: Vector store to retrieve from; RAG_CHROMA or RAG_MONGODB.
        prompt: The user's question.

    Returns:
        ``(completion, rag_chain, cb)``: the chain's output for the query
        (source documents included, since ``return_source_documents=True``),
        the RetrievalQA chain itself, and the OpenAI callback handler that was
        active during the call.

    Raises:
        ValueError: If ``rag_option`` is neither RAG_CHROMA nor RAG_MONGODB.
    """
    # Validate the store choice up front. Previously an unknown option left
    # ``db`` unbound and surfaced later as a confusing UnboundLocalError.
    if rag_option == RAG_CHROMA:
        open_store = retrieve_documents_chroma
    elif rag_option == RAG_MONGODB:
        open_store = retrieve_documents_mongodb
    else:
        raise ValueError(
            f"Unsupported rag_option {rag_option!r}; "
            f"expected {RAG_CHROMA!r} or {RAG_MONGODB!r}"
        )

    # For valid options the LLM is still built before the store is opened,
    # matching the original call order.
    llm = get_llm(config, openai_api_key)
    db = open_store()

    rag_chain = RetrievalQA.from_chain_type(llm,
                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
                                                                 "verbose": True},
                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
                                            return_source_documents = True)

    with get_openai_callback() as cb:
        completion = rag_chain({"query": prompt})

    return completion, rag_chain, cb