Spaces:
Build error
Build error
File size: 7,522 Bytes
7d6d701 4fb4308 7d6d701 99bbf81 0a1cd5f f4087b0 55274da f4087b0 a426126 bf1b617 4fb4308 040f529 e02bd6d a627434 7d6d701 55274da 6772176 2db1016 994b8cd fd30064 d693fc5 99bbf81 215ebea 60832a0 99bbf81 bf1b617 86d2f65 99bbf81 86d2f65 99bbf81 86d2f65 7e2b6ca 86d2f65 99bbf81 86d2f65 b1e2693 99bbf81 86d2f65 84e076a 99bbf81 bf1b617 99bbf81 53d588f 86d2f65 99bbf81 53d588f 99bbf81 53d588f 86d2f65 f5190b5 99bbf81 9549818 33f1a4f 99bbf81 503e34f 99bbf81 e6f4b15 1dcde2f 503e34f 33f1a4f 84e076a 542a800 64931b6 38ee3ac 1dcde2f 33f1a4f 86d2f65 ebcdcac 044c0a3 86d2f65 044c0a3 ebcdcac 044c0a3 1e517cc acf522c 26b6a5b ddfaa69 12d440a 1e517cc 1283168 9102fcd 99bbf81 95650ef 1283168 f7926b9 1e517cc d693fc5 53d588f bf1b617 99bbf81 503e34f 1dcde2f ddfaa69 d693fc5 db5f00f 99bbf81 503e34f 1dcde2f ddfaa69 1283168 12d440a 99bbf81 ddfaa69 1a8b52b ddfaa69 1283168 12d440a 4fb4308 c2e6078 37ab520 043b829 99bbf81 4fb4308 8d60a3f 7d6d701 4fb4308 7d6d701 bb79bf1 99bbf81 b7d5b27 908ded3 fd30064 4fb4308 a4da0c1 |
|
import gradio as gr
import openai, os, time
from dotenv import load_dotenv, find_dotenv
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from rag import llm_chain, rag_chain
from trace import wandb_trace
# Load variables from a .env file (if one is found) into the process
# environment. This must run before the os.environ[...] lookups below.
_ = load_dotenv(find_dotenv())
# Source material that document_loading_splitting() ingests.
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
# Directory handed to YoutubeAudioLoader, and the Chroma persist directory.
YOUTUBE_DIR = "/data/youtube"
CHROMA_DIR = "/data/chroma"
# MongoDB Atlas connection settings. The URI is a required environment
# variable: a missing key raises KeyError at import time.
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
MONGODB_DB_NAME = "langchain_db"
MONGODB_COLLECTION_NAME = "gpt-4"
MONGODB_INDEX_NAME = "default"
# Prompt templates; the template text comes from the required environment
# variables LLM_TEMPLATE and RAG_TEMPLATE.
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
# Labels of the RAG radio options shown in the UI and compared in invoke().
RAG_OFF = "Off"
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"
# Module-level MongoDB client and the collection used by
# document_storage_mongodb(); both are created when the module is imported.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
# Tunables: chunk_overlap / chunk_size feed the text splitter, k is the
# retriever's search_kwargs["k"], model_name / temperature configure ChatOpenAI.
config = {
"chunk_overlap": 150,
"chunk_size": 1500,
"k": 3,
"model_name": "gpt-4-0613",
"temperature": 0,
}
def document_loading_splitting():
    """Load the PDF, web page and YouTube sources and split them into chunks.

    Sources are loaded in a fixed order (PDF, web, YouTube) and the combined
    documents are split with RecursiveCharacterTextSplitter using the
    ``chunk_size`` and ``chunk_overlap`` values from ``config``.

    Returns:
        The list of split documents produced by the text splitter.
    """
    loaded = []

    # PDF
    loaded += PyPDFLoader(PDF_URL).load()

    # Web page
    loaded += WebBaseLoader(WEB_URL).load()

    # YouTube: audio is fetched into YOUTUBE_DIR and parsed by
    # OpenAIWhisperParser.
    youtube_urls = [YOUTUBE_URL_1, YOUTUBE_URL_2, YOUTUBE_URL_3]
    audio_blobs = YoutubeAudioLoader(youtube_urls, YOUTUBE_DIR)
    loaded += GenericLoader(audio_blobs, OpenAIWhisperParser()).load()

    # Split everything with the configured chunking parameters.
    splitter = RecursiveCharacterTextSplitter(
        chunk_overlap=config["chunk_overlap"],
        chunk_size=config["chunk_size"],
    )
    return splitter.split_documents(loaded)
def document_storage_chroma(documents):
    """Embed ``documents`` and store them in the Chroma store at CHROMA_DIR.

    Args:
        documents: Split documents to embed and store.

    Returns:
        None. The created vector store is not returned.
    """
    embeddings = OpenAIEmbeddings(disallowed_special=())
    Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory=CHROMA_DIR,
    )
def document_storage_mongodb(documents):
    """Embed ``documents`` and store them in the module-level MongoDB collection.

    Args:
        documents: Split documents to embed and store.

    Returns:
        None. The created vector store is not returned.
    """
    embeddings = OpenAIEmbeddings(disallowed_special=())
    MongoDBAtlasVectorSearch.from_documents(
        documents=documents,
        embedding=embeddings,
        collection=collection,
        index_name=MONGODB_INDEX_NAME,
    )
def document_retrieval_chroma(llm, prompt):
    """Open the Chroma vector store persisted at CHROMA_DIR.

    Args:
        llm: Unused; kept so the signature matches document_retrieval_mongodb.
        prompt: Unused; kept for the same reason.

    Returns:
        A Chroma vector store backed by CHROMA_DIR.
    """
    # Use the same embedding settings as the storage path and the MongoDB
    # retrieval path. Without disallowed_special=(), a query containing
    # special-token text would be rejected on this path only.
    embeddings = OpenAIEmbeddings(disallowed_special=())
    return Chroma(embedding_function=embeddings,
                  persist_directory=CHROMA_DIR)
def document_retrieval_mongodb(llm, prompt):
    """Open the MongoDB Atlas vector store for the configured collection.

    Args:
        llm: Unused by this function.
        prompt: Unused by this function.

    Returns:
        A MongoDBAtlasVectorSearch built from the cluster URI, the
        "<db>.<collection>" namespace and MONGODB_INDEX_NAME.
    """
    namespace = f"{MONGODB_DB_NAME}.{MONGODB_COLLECTION_NAME}"
    embeddings = OpenAIEmbeddings(disallowed_special=())
    return MongoDBAtlasVectorSearch.from_connection_string(
        MONGODB_ATLAS_CLUSTER_URI,
        namespace,
        embeddings,
        index_name=MONGODB_INDEX_NAME,
    )
def llm_chain(llm, prompt):
    """Run ``prompt`` through a plain LLMChain (no retrieval).

    Args:
        llm: Chat model passed to LLMChain.
        prompt: User question, supplied as the "question" template variable
            of LLM_CHAIN_PROMPT.

    Returns:
        A ``(completion, chain)`` tuple: the result of ``LLMChain.generate``
        and the chain that produced it.
    """
    chain = LLMChain(llm=llm, prompt=LLM_CHAIN_PROMPT, verbose=False)
    inputs = [{"question": prompt}]
    return chain.generate(inputs), chain
def rag_chain(llm, prompt, db):
    """Run ``prompt`` through a RetrievalQA chain backed by ``db``.

    Args:
        llm: Chat model passed to RetrievalQA.
        prompt: User question, supplied under the "query" key.
        db: Vector store; its retriever is created with
            ``search_kwargs={"k": config["k"]}``.

    Returns:
        A ``(completion, chain)`` tuple: the chain's output for the query
        (source documents are requested) and the chain itself.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        chain_type_kwargs={"prompt": RAG_CHAIN_PROMPT},
        retriever=db.as_retriever(search_kwargs={"k": config["k"]}),
        return_source_documents=True,
        verbose=False,
    )
    answer = qa_chain({"query": prompt})
    return answer, qa_chain
def invoke(openai_api_key, rag_option, prompt):
    """Answer ``prompt`` with the configured chat model, optionally using RAG.

    Args:
        openai_api_key: OpenAI API key passed to ChatOpenAI.
        rag_option: RAG_CHROMA or RAG_MONGODB selects the matching vector
            store; any other non-None value runs the plain LLM chain.
        prompt: User question.

    Returns:
        The completion text. For the plain LLM chain this stays "" when the
        first generation is missing.

    Raises:
        gr.Error: If an input is missing, or wrapping any exception raised
            while building or running the chain.
    """
    # Truthiness checks reject both "" and None.
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")

    # Defaults reported to wandb_trace when a step fails before setting them.
    chain = None
    completion = ""
    result = ""
    generation_info = ""
    llm_output = ""
    err_msg = ""

    # Taken before the try block so the finally clause can always read it.
    start_time_ms = round(time.time() * 1000)

    try:
        llm = ChatOpenAI(model_name=config["model_name"],
                         openai_api_key=openai_api_key,
                         temperature=config["temperature"])

        # Only retrieval happens here; the ingestion calls
        # (document_loading_splitting + document_storage_*) are not invoked
        # on this path.
        if rag_option == RAG_CHROMA:
            retrieve = document_retrieval_chroma
        elif rag_option == RAG_MONGODB:
            retrieve = document_retrieval_mongodb
        else:
            retrieve = None

        if retrieve is not None:
            db = retrieve(llm, prompt)
            completion, chain = rag_chain(llm, prompt, db)
            result = completion["result"]
        else:
            completion, chain = llm_chain(llm, prompt)
            candidates = completion.generations[0]
            if candidates is not None and candidates[0] is not None:
                result = candidates[0].text
                generation_info = candidates[0].generation_info
                llm_output = completion.llm_output
    except Exception as e:
        err_msg = e
        # Surface the failure in the Gradio UI and keep the original cause.
        raise gr.Error(e) from e
    finally:
        # Trace every call, successful or not.
        end_time_ms = round(time.time() * 1000)
        wandb_trace(rag_option,
                    prompt,
                    completion,
                    result,
                    generation_info,
                    llm_output,
                    chain,
                    err_msg,
                    start_time_ms,
                    end_time_ms)

    return result
# Close any Gradio apps already running in this process before building
# a new one.
gr.close_all()

# Single-function UI: API key, RAG mode and prompt in; completion text out.
# The description text comes from the required DESCRIPTION environment
# variable.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Radio([RAG_OFF, RAG_CHROMA, RAG_MONGODB],
                 label="Retrieval Augmented Generation",
                 value=RAG_OFF),
        gr.Textbox(label="Prompt",
                   value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?",
                   lines=1),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM & RAG",
    description=os.environ["DESCRIPTION"],
)

demo.launch()