File size: 9,092 Bytes
7d6d701
d693fc5
7d6d701
0a1cd5f
f4087b0
55274da
f4087b0
 
 
55274da
f4087b0
 
 
 
a426126
1ad0dcf
bf1b617
 
a627434
 
7d6d701
 
 
8cf750e
 
4d8a63e
 
 
 
 
53d588f
 
4a15de2
 
38ee3ac
84e076a
 
 
95650ef
38ee3ac
 
 
84ce8cd
6f02f68
b7d2e54
 
e38fd6d
33156e9
 
b610816
cd9c510
 
6553dbd
55274da
6772176
2db1016
 
 
994b8cd
d693fc5
 
 
 
bf1b617
86d2f65
 
 
 
 
 
7e2b6ca
86d2f65
 
b1e2693
 
 
 
 
86d2f65
84e076a
 
86d2f65
bf1b617
 
 
53d588f
 
 
86d2f65
bf1b617
53d588f
 
 
 
86d2f65
f5190b5
503e34f
53d588f
503e34f
9549818
33f1a4f
4d8a63e
 
 
 
503e34f
 
 
64931b6
e6f4b15
1dcde2f
503e34f
 
33f1a4f
 
84e076a
542a800
64931b6
38ee3ac
1dcde2f
33f1a4f
3fb4fb3
f0bcf66
6cb1c29
1dcde2f
6017859
12d440a
 
f0bcf66
6017859
 
6cb1c29
12d440a
6017859
511bcdc
1c0e451
6017859
511bcdc
6017859
 
 
 
 
 
8573a63
511bcdc
043b829
0c788c0
6cb1c29
09d8d95
6cb1c29
 
86d2f65
ebcdcac
044c0a3
86d2f65
044c0a3
ebcdcac
044c0a3
acf522c
26b6a5b
043b829
12d440a
1283168
9102fcd
95650ef
1283168
f7926b9
d693fc5
53d588f
bf1b617
503e34f
1dcde2f
8d60a3f
d693fc5
db5f00f
 
503e34f
1dcde2f
8d60a3f
1283168
12d440a
aa112f1
 
1283168
12d440a
c2e6078
37ab520
043b829
3fb4fb3
8d60a3f
7d6d701
 
 
bb79bf1
 
aa626d0
b7d5b27
908ded3
7d6d701
a4da0c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import gradio as gr
import openai, os, time, wandb

from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch

from pymongo import MongoClient

from wandb.sdk.data_types.trace_tree import Trace

from dotenv import load_dotenv, find_dotenv
# Load variables from the nearest .env file (if one is found) into os.environ
# before the required settings below are read; the return value is unused.
_ = load_dotenv(find_dotenv())

# Weights & Biases API key. Read eagerly so a missing variable fails at import
# time with KeyError. NOTE(review): not referenced again in the visible code —
# presumably wandb reads it from the environment itself; confirm.
WANDB_API_KEY = os.environ["WANDB_API_KEY"]

# MongoDB Atlas settings backing the "MongoDB" RAG option.
MONGODB_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
client = MongoClient(MONGODB_URI)
MONGODB_DB_NAME = "langchain_db"
MONGODB_COLLECTION_NAME = "gpt-4"
# Collection handle passed to MongoDBAtlasVectorSearch.from_documents when
# storing embeddings (document_storage_mongodb).
MONGODB_COLLECTION = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
# Name of the Atlas vector search index used for both storage and retrieval.
MONGODB_INDEX_NAME = "default"

# Text shown below the title of the Gradio interface (passed to gr.Interface).
description = os.environ["DESCRIPTION"]

# Shared settings for document splitting, retrieval and the chat model.
config = {
    "chunk_overlap": 150,  # overlap between consecutive chunks (RecursiveCharacterTextSplitter)
    "chunk_size": 1500,    # maximum chunk size (RecursiveCharacterTextSplitter)
    "k": 3,                # number of chunks the retriever returns per query
    "model_name": "gpt-4",
    "temperature": 0,
}

# Instructions shared by both prompts, including the sign-off line the model
# is asked to append to every answer.
template = """If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Always say "Thanks for using the 🧠 app - Bernd" at the end of the answer. """

# Plain LLM prompt: only {question} is substituted.
llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
# RAG prompt: {context} receives the retrieved chunks, {question} the user prompt.
rag_template = "Use the following pieces of context to answer the question at the end. " + template + "{context}. Question: {question} Helpful Answer: "

LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = llm_template)
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = rag_template)

# Persist directory of the Chroma store and the directory YoutubeAudioLoader
# saves downloaded audio to.
CHROMA_DIR  = "/data/chroma"
YOUTUBE_DIR = "/data/youtube"

# Source documents indexed by document_loading_splitting().
PDF_URL       = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL       = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"

# Choices of the "Retrieval Augmented Generation" radio button; invoke()
# compares the selected value against these.
RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"

def document_loading_splitting():
    """Load every configured source and split it into overlapping chunks.

    Sources, in order: the PDF at PDF_URL, the web page at WEB_URL, and the
    Whisper transcripts of the three YouTube videos (audio saved under
    YOUTUBE_DIR).

    Returns:
        The chunks produced by RecursiveCharacterTextSplitter configured with
        config["chunk_size"] and config["chunk_overlap"].
    """
    # Factories keep loader construction lazy, so each source is built and
    # loaded one after another in the same order as listed.
    loader_factories = (
        lambda: PyPDFLoader(PDF_URL),
        lambda: WebBaseLoader(WEB_URL),
        lambda: GenericLoader(
            YoutubeAudioLoader([YOUTUBE_URL_1, YOUTUBE_URL_2, YOUTUBE_URL_3], YOUTUBE_DIR),
            OpenAIWhisperParser(),
        ),
    )
    documents = []
    for make_loader in loader_factories:
        documents += make_loader().load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_overlap=config["chunk_overlap"],
        chunk_size=config["chunk_size"],
    )
    return splitter.split_documents(documents)

def document_storage_chroma(splits):
    """Embed *splits* with OpenAI and persist them to the Chroma store at CHROMA_DIR."""
    # disallowed_special=() makes special-token strings in the text get
    # encoded as ordinary text instead of raising.
    embeddings = OpenAIEmbeddings(disallowed_special=())
    Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        persist_directory=CHROMA_DIR,
    )

def document_storage_mongodb(splits):
    """Embed *splits* with OpenAI and insert them into MONGODB_COLLECTION.

    The vectors are associated with the Atlas search index MONGODB_INDEX_NAME.
    """
    # Same tokenizer setting as the other embedding call sites in this file.
    embeddings = OpenAIEmbeddings(disallowed_special=())
    MongoDBAtlasVectorSearch.from_documents(
        documents=splits,
        embedding=embeddings,
        collection=MONGODB_COLLECTION,
        index_name=MONGODB_INDEX_NAME,
    )

def document_retrieval_chroma(llm, prompt):
    """Open the persisted Chroma store at CHROMA_DIR for similarity search.

    Args:
        llm: unused; kept for signature parity with document_retrieval_mongodb.
        prompt: unused; kept for the same reason.

    Returns:
        A Chroma vector store whose queries are embedded with OpenAIEmbeddings.
    """
    # Use the same tokenizer setting as document_storage_chroma and the
    # MongoDB retriever. With the default disallowed_special, a user prompt
    # containing a special-token string such as "<|endoftext|>" would make
    # the query embedding raise instead of being embedded as plain text.
    return Chroma(embedding_function=OpenAIEmbeddings(disallowed_special=()),
                  persist_directory=CHROMA_DIR)

def document_retrieval_mongodb(llm, prompt):
    """Open the MongoDB Atlas vector store for similarity search.

    Args:
        llm: unused; kept for signature parity with document_retrieval_chroma.
        prompt: unused; kept for the same reason.

    Returns:
        A MongoDBAtlasVectorSearch over "<db>.<collection>" that uses the
        Atlas index MONGODB_INDEX_NAME.
    """
    namespace = f"{MONGODB_DB_NAME}.{MONGODB_COLLECTION_NAME}"
    embeddings = OpenAIEmbeddings(disallowed_special=())
    return MongoDBAtlasVectorSearch.from_connection_string(
        MONGODB_URI,
        namespace,
        embeddings,
        index_name=MONGODB_INDEX_NAME,
    )

def llm_chain(llm, prompt):
    """Answer *prompt* with *llm* directly, without retrieval.

    Returns:
        (completion, chain): the result of LLMChain.generate for the single
        input, and the LLMChain that produced it.
    """
    # Named ``chain`` so the local does not shadow this function's name.
    chain = LLMChain(llm=llm, prompt=LLM_CHAIN_PROMPT, verbose=False)
    batch = [{"question": prompt}]
    return chain.generate(batch), chain

def rag_chain(llm, prompt, db):
    """Answer *prompt* with *llm* using the top config["k"] chunks from *db*.

    Returns:
        (completion, chain): the RetrievalQA output dict (it includes
        "result", and source documents are requested) and the chain itself.
    """
    retriever = db.as_retriever(search_kwargs={"k": config["k"]})
    # Named ``chain`` so the local does not shadow this function's name.
    chain = RetrievalQA.from_chain_type(
        llm,
        retriever=retriever,
        chain_type_kwargs={"prompt": RAG_CHAIN_PROMPT},
        return_source_documents=True,
        verbose=False,
    )
    return chain({"query": prompt}), chain

def _trace_failed(err_msg):
    """Return True unless *err_msg* is the empty-string "no error" sentinel.

    Checking the type, not just ``str(err_msg)``, keeps an exception whose
    message is empty (e.g. ``ValueError()``) from being logged as a success.
    """
    return not (isinstance(err_msg, str) and err_msg == "")


def _trace_model_dict(rag_option, chain):
    """Describe the model, prompt and retriever of *chain* for the trace.

    Fields that need *chain* are left as "" when *chain* is None, i.e. when
    the request failed before any chain was built.
    """
    details = {
        "llm_client": "",
        "model_name": config["model_name"],
        "temperature": config["temperature"],
        "chain_prompt": "",
        "retriever": "",
    }
    if chain is None:
        return details
    if rag_option == RAG_OFF:
        # Plain LLMChain: the LLM and prompt are direct attributes.
        details["llm_client"] = str(chain.llm.client)
        details["chain_prompt"] = str(chain.prompt)
    else:
        # RetrievalQA: the LLM and prompt live on the inner combine-documents chain.
        inner = chain.combine_documents_chain.llm_chain
        details["llm_client"] = str(inner.llm.client)
        details["chain_prompt"] = str(inner.prompt)
        details["retriever"] = str(chain.retriever)
    return details


def wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms):
    """Log one request as a W&B Trace in the "openai-llm-rag" project.

    Args:
        rag_option: the selected RAG option (RAG_OFF, RAG_CHROMA or RAG_MONGODB).
        prompt: the user prompt.
        completion: raw chain output, or "" if none was produced.
        result: the answer text returned to the user.
        chain: the chain that ran, or None if the failure happened before a
            chain was built.
        err_msg: "" on success; otherwise the exception (or message) raised.
        start_time_ms: request start time in milliseconds.
        end_time_ms: request end time in milliseconds.
    """
    failed = _trace_failed(err_msg)
    ok = not failed
    rag_used = ok and rag_option != RAG_OFF
    status_message = str(err_msg)
    if failed and not status_message:
        # Exception without a message: report its type instead of "".
        status_message = type(err_msg).__name__

    wandb.init(project = "openai-llm-rag")
    try:
        trace = Trace(
            kind = "chain",
            name = "" if chain is None else type(chain).__name__,
            status_code = "error" if failed else "success",
            status_message = status_message,
            metadata = {
                "chunk_overlap": config["chunk_overlap"] if rag_used else "",
                "chunk_size": config["chunk_size"] if rag_used else "",
            },
            # Inputs and outputs are only recorded for successful requests.
            inputs = {"rag_option": rag_option if ok else "",
                      "prompt": prompt if ok else "",
            },
            outputs = {"result": result if ok else "",
                       "completion": str(completion) if ok else "",
            },
            model_dict = _trace_model_dict(rag_option, chain),
            start_time_ms = start_time_ms,
            end_time_ms = end_time_ms
        )
        trace.log("test")
    finally:
        # Close the run even if building or logging the trace fails.
        wandb.finish()

def invoke(openai_api_key, rag_option, prompt):
    """Answer *prompt* with the configured chat model, optionally using RAG.

    Every request that passes input validation is traced to W&B, including
    failed ones.

    Args:
        openai_api_key: the user's OpenAI API key.
        rag_option: RAG_OFF, RAG_CHROMA or RAG_MONGODB.
        prompt: the user's question.

    Returns:
        The answer text.

    Raises:
        gr.Error: if an input is missing or answering the prompt fails.
    """
    # `not x` also rejects None, which the previous `== ""` checks let through.
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    chain = None
    completion = ""
    result = ""
    err_msg = ""
    start_time_ms = round(time.time() * 1000)
    try:
        llm = ChatOpenAI(model_name = config["model_name"],
                         openai_api_key = openai_api_key,
                         temperature = config["temperature"])
        if rag_option == RAG_CHROMA:
            # To (re)build the index: splits = document_loading_splitting();
            # document_storage_chroma(splits)
            db = document_retrieval_chroma(llm, prompt)
            completion, chain = rag_chain(llm, prompt, db)
            result = completion["result"]
        elif rag_option == RAG_MONGODB:
            # To (re)build the index: splits = document_loading_splitting();
            # document_storage_mongodb(splits)
            db = document_retrieval_mongodb(llm, prompt)
            completion, chain = rag_chain(llm, prompt, db)
            result = completion["result"]
        else:
            completion, chain = llm_chain(llm, prompt)
            generations = completion.generations
            # Guard both list levels: indexing an empty list raises IndexError,
            # which the old `!= None` checks did not prevent.
            first = generations[0][0] if generations and generations[0] else None
            result = first.text if first is not None else ""
    except Exception as e:
        err_msg = e
        # gr.Error takes a message string; keep the original as the cause.
        raise gr.Error(str(e) or type(e).__name__) from e
    finally:
        end_time_ms = round(time.time() * 1000)
        try:
            wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms)
        except Exception:
            # Tracing is best-effort: a W&B failure must not replace the
            # user's answer or the real error.
            import logging
            logging.getLogger(__name__).exception("W&B tracing failed")
    return result

# Close any Gradio servers left running in this process before starting a new one.
gr.close_all()

# Components are created in the same order as the arguments of invoke().
_inputs = [
    gr.Textbox(label="OpenAI API Key", type="password", lines=1),
    gr.Radio(
        [RAG_OFF, RAG_CHROMA, RAG_MONGODB],
        label="Retrieval Augmented Generation",
        value=RAG_OFF,
    ),
    gr.Textbox(
        label="Prompt",
        value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?",
        lines=1,
    ),
]
_outputs = [gr.Textbox(label="Completion", lines=1)]

demo = gr.Interface(
    fn=invoke,
    inputs=_inputs,
    outputs=_outputs,
    title="Generative AI - LLM & RAG",
    description=description,
)
demo.launch()