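"""Gradio app that compares a plain GPT-4 LLM chain with Retrieval Augmented
Generation (RAG) over a PDF, a web page, and three YouTube videos, using either
a local Chroma index or MongoDB Atlas Vector Search as the vector store.
Each invocation is traced to Weights & Biases."""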
import gradio as gr
import os, time, wandb
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from wandb.sdk.data_types.trace_tree import Trace
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
MONGODB_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
client = MongoClient(MONGODB_URI)
MONGODB_DB_NAME = "langchain_db"
MONGODB_COLLECTION_NAME = "gpt-4"
MONGODB_COLLECTION = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
MONGODB_INDEX_NAME = "default"
description = os.environ["DESCRIPTION"]
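# Shared settings for chunking, retrieval, and the chat model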
config = {
"chunk_overlap": 150,
"chunk_size": 1500,
"k": 3,
"model_name": "gpt-4",
"temperature": 0,
}
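# Prompt templates: one for the plain LLM chain, one for the RAG chain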
template = """If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Always say "Thanks for using the 🧠 app - Bernd" at the end of the answer. """
llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
rag_template = "Use the following pieces of context to answer the question at the end. " + template + "{context}. Question: {question} Helpful Answer: "
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = llm_template)
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = rag_template)
CHROMA_DIR = "/data/chroma"
YOUTUBE_DIR = "/data/youtube"
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"
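# One-time ingestion: load the PDF, the web page, and the YouTube videos
# (transcribed with OpenAI Whisper), then split everything into overlapping chunks.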
def document_loading_splitting():
# Document loading
docs = []
# Load PDF
loader = PyPDFLoader(PDF_URL)
docs.extend(loader.load())
# Load Web
loader = WebBaseLoader(WEB_URL)
docs.extend(loader.load())
# Load YouTube
loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
YOUTUBE_URL_2,
YOUTUBE_URL_3], YOUTUBE_DIR),
OpenAIWhisperParser())
docs.extend(loader.load())
# Document splitting
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
chunk_size = config["chunk_size"])
splits = text_splitter.split_documents(docs)
return splits
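# Build (or rebuild) the local Chroma vector store from the splits.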
def document_storage_chroma(splits):
    db = Chroma.from_documents(documents = splits,
                               embedding = OpenAIEmbeddings(disallowed_special = ()),
                               persist_directory = CHROMA_DIR)
    db.persist()  # flush the index to disk (required by older Chroma releases, a no-op otherwise)
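# Build (or rebuild) the MongoDB Atlas vector store from the splits.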
def document_storage_mongodb(splits):
MongoDBAtlasVectorSearch.from_documents(documents = splits,
embedding = OpenAIEmbeddings(disallowed_special = ()),
collection = MONGODB_COLLECTION,
index_name = MONGODB_INDEX_NAME)
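# Open the persisted Chroma index for retrieval (llm and prompt are currently unused).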
def document_retrieval_chroma(llm, prompt):
    db = Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
                persist_directory = CHROMA_DIR)
    return db
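# Connect to the MongoDB Atlas vector search index for retrieval (llm and prompt are currently unused).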
def document_retrieval_mongodb(llm, prompt):
db = MongoDBAtlasVectorSearch.from_connection_string(MONGODB_URI,
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
OpenAIEmbeddings(disallowed_special = ()),
index_name = MONGODB_INDEX_NAME)
return db
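# Plain LLM chain: answers from the model's own knowledge, no retrieval.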
def llm_chain(llm, prompt):
llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT, verbose = False)
completion = llm_chain.generate([{"question": prompt}])
return completion, llm_chain
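# RAG chain: retrieve the top-k chunks from the vector store and stuff them into the prompt as context.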
def rag_chain(llm, prompt, db):
rag_chain = RetrievalQA.from_chain_type(llm,
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
return_source_documents = True,
verbose = False)
completion = rag_chain({"query": prompt})
return completion, rag_chain
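# Log a single trace to Weights & Biases: status, prompt, completion, timing,
# and (when a chain was built) its model settings.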
def wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms):
wandb.init(project = "openai-llm-rag")
trace = Trace(
kind = "chain",
        name = "" if (chain is None) else type(chain).__name__,
status_code = "success" if (str(err_msg) == "") else "error",
status_message = str(err_msg),
metadata = {
"chunk_overlap": config["chunk_overlap"] if (str(err_msg) == "" and rag_option != RAG_OFF) else "",
"chunk_size": config["chunk_size"] if (str(err_msg) == "" and rag_option != RAG_OFF) else "",
},
inputs = {"rag_option": rag_option if (str(err_msg) == "") else "",
"prompt": prompt if (str(err_msg) == "") else "",
},
outputs = {"result": result if (str(err_msg) == "") else "",
"completion": str(completion) if (str(err_msg) == "") else "",
},
        # chain is None when an error occurred before the chain was built,
        # so only read its attributes when it exists
        model_dict = {} if (chain is None) else {
            "llm_client": str(chain.llm.client) if (rag_option == RAG_OFF) else
                          str(chain.combine_documents_chain.llm_chain.llm.client),
            "model_name": config["model_name"],
            "temperature": config["temperature"],
            "chain_prompt": str(chain.prompt) if (rag_option == RAG_OFF) else
                            str(chain.combine_documents_chain.llm_chain.prompt),
            "retriever": "" if (rag_option == RAG_OFF) else str(chain.retriever),
        },
start_time_ms = start_time_ms,
end_time_ms = end_time_ms
)
    trace.log("test")  # log the trace under the key "test" in the active W&B run
wandb.finish()
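# Gradio callback: validates the inputs, runs the selected chain, and always
# logs a W&B trace in the finally block, even when an error occurs.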
def invoke(openai_api_key, rag_option, prompt):
if (openai_api_key == ""):
raise gr.Error("OpenAI API Key is required.")
if (rag_option is None):
raise gr.Error("Retrieval Augmented Generation is required.")
if (prompt == ""):
raise gr.Error("Prompt is required.")
chain = None
completion = ""
result = ""
err_msg = ""
try:
start_time_ms = round(time.time() * 1000)
llm = ChatOpenAI(model_name = config["model_name"],
openai_api_key = openai_api_key,
temperature = config["temperature"])
if (rag_option == RAG_CHROMA):
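            # One-time ingestion: uncomment the next two lines to (re)build the Chroma index.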
#splits = document_loading_splitting()
#document_storage_chroma(splits)
db = document_retrieval_chroma(llm, prompt)
completion, chain = rag_chain(llm, prompt, db)
result = completion["result"]
elif (rag_option == RAG_MONGODB):
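            # One-time ingestion: uncomment the next two lines to (re)build the MongoDB Atlas index.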
#splits = document_loading_splitting()
#document_storage_mongodb(splits)
db = document_retrieval_mongodb(llm, prompt)
completion, chain = rag_chain(llm, prompt, db)
result = completion["result"]
else:
completion, chain = llm_chain(llm, prompt)
            result = completion.generations[0][0].text if (completion.generations and
                                                           completion.generations[0]) else ""
    except Exception as e:
        err_msg = e
        raise gr.Error(str(e))
finally:
end_time_ms = round(time.time() * 1000)
wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms)
return result
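# Gradio UI: OpenAI API key, RAG mode selector (Off / Chroma / MongoDB), and the prompt.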
gr.close_all()
demo = gr.Interface(fn=invoke,
inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB], label = "Retrieval Augmented Generation", value = RAG_OFF),
gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1)],
outputs = [gr.Textbox(label = "Completion", lines = 1)],
title = "Generative AI - LLM & RAG",
description = description)
demo.launch()