Spaces:
Build error
Build error
File size: 5,032 Bytes
03ab966 3ede494 eceefb4 40e55f0 7ae1795 aa92ef4 516ec1c 6ea7ef9 516ec1c 3ede494 340e058 3ede494 340e058 3ede494 56a0e0e 3ede494 340e058 3ede494 56a0e0e 3ede494 340e058 90a2028 4115e3a 4a679bd 340e058 3ede494 56a0e0e 340e058 4115e3a 340e058 6ea7ef9 3ede494 340e058 4115e3a 340e058 6ea7ef9 3ede494 40e55f0 4115e3a 6ea7ef9 3ede494 40e55f0 4115e3a 3ede494 6ea7ef9 3ede494 309e834 08cc2d6 03ab966 08cc2d6 3ede494 f43960a 3ede494 f43960a 3ede494 4e80daf 08cc2d6 f43960a 4e80daf 40e55f0 f43960a 3ede494 15fb20f f43960a 15fb20f 90a2028 f43960a 03ab966 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import openai, os
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
# Backend identifiers accepted as `rag_option` by `rag_chain`.
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"

# Sources ingested by `document_loading`.
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL = "https://openai.com/research/gpt-4"
YOUTUBE_URL = "https://www.youtube.com/watch?v=qdd2GZ0DVgc"

# Local directories: YouTube download target and Chroma persist directory.
YOUTUBE_DIR = "/data/youtube"
CHROMA_DIR = "/data/chroma"

# MongoDB Atlas settings. The cluster URI is mandatory: a missing environment
# variable raises KeyError when this module is imported.
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
MONGODB_DB_NAME = "langchain_db"
MONGODB_COLLECTION_NAME = "gpt-4"
MONGODB_INDEX_NAME = "default"

# Prompt templates come from the environment (LLM_TEMPLATE / RAG_TEMPLATE);
# both variables are required at import time.
LLM_CHAIN_PROMPT = PromptTemplate(
    input_variables=["question"],
    template=os.environ["LLM_TEMPLATE"],
)
RAG_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=os.environ["RAG_TEMPLATE"],
)

# One client / collection handle shared by the MongoDB storage code.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
def document_loading(*, include_youtube=False):
    """Load the source documents that feed the vector stores.

    The PDF at ``PDF_URL`` and the web page at ``WEB_URL`` are always loaded.
    The YouTube source (``YOUTUBE_URL``) is loaded only on request, via
    ``GenericLoader(YoutubeAudioLoader(...), OpenAIWhisperParser())``. It is
    off by default, which matches the previous behaviour where that loader
    was commented out.

    Args:
        include_youtube: When true, also load ``YOUTUBE_URL``, using
            ``YOUTUBE_DIR`` as the loader's directory argument.

    Returns:
        A list of documents from every enabled source, in load order
        (PDF, web, then YouTube if enabled).
    """
    docs = []

    # PDF
    docs.extend(PyPDFLoader(PDF_URL).load())
    print("### Load PDF")

    # Web
    docs.extend(WebBaseLoader(WEB_URL).load())
    print("### Load Web")

    # YouTube: opt-in. Previously this step was disabled but still logged
    # "### Load YouTube", which misreported what had been ingested.
    if include_youtube:
        loader = GenericLoader(
            YoutubeAudioLoader([YOUTUBE_URL], YOUTUBE_DIR),
            OpenAIWhisperParser(),
        )
        docs.extend(loader.load())
        print("### Load YouTube")
    else:
        print("### Skip YouTube")

    return docs
def document_splitting(config, docs):
    """Split *docs* into chunks using sizes taken from *config*.

    *config* must contain ``"chunk_overlap"`` and ``"chunk_size"``.
    Returns whatever ``RecursiveCharacterTextSplitter.split_documents``
    produces for *docs*.
    """
    overlap = config["chunk_overlap"]
    size = config["chunk_size"]
    splitter = RecursiveCharacterTextSplitter(chunk_overlap=overlap, chunk_size=size)
    print("### Split")
    return splitter.split_documents(docs)
def document_storage_chroma(chunks):
    """Embed *chunks* and build a Chroma store with ``persist_directory=CHROMA_DIR``.

    The created store object is not returned.
    """
    print("### Store Chroma")
    # An empty tuple disables the "disallowed special tokens" check, so text
    # that happens to contain such tokens is not rejected (presumably needed
    # for scraped content).
    embedder = OpenAIEmbeddings(disallowed_special=())
    Chroma.from_documents(
        documents=chunks,
        embedding=embedder,
        persist_directory=CHROMA_DIR,
    )
def document_storage_mongodb(chunks):
    """Embed *chunks* and store them in the shared MongoDB Atlas ``collection``.

    Uses the Atlas search index named ``MONGODB_INDEX_NAME``. The created
    store object is not returned.
    """
    print("### Store MongoDB")
    embedder = OpenAIEmbeddings(disallowed_special=())
    MongoDBAtlasVectorSearch.from_documents(
        documents=chunks,
        embedding=embedder,
        collection=collection,
        index_name=MONGODB_INDEX_NAME,
    )
def document_retrieval_chroma():
    """Return a Chroma vector store opened on ``CHROMA_DIR`` for querying."""
    print("### Retrieve Chroma")
    embedder = OpenAIEmbeddings(disallowed_special=())
    store = Chroma(embedding_function=embedder, persist_directory=CHROMA_DIR)
    return store
def document_retrieval_mongodb():
    """Return a MongoDB Atlas vector store over the shared ``collection``.

    Previously this called ``MongoDBAtlasVectorSearch.from_connection_string``,
    which creates a fresh ``MongoClient`` on every call. That client is never
    closed, so each retrieval (one per ``rag_chain`` request) leaked a
    connection pool. The store is now built directly on the module-level
    ``collection``, which points at the same
    ``MONGODB_DB_NAME.MONGODB_COLLECTION_NAME`` namespace. This is also the
    handle that ``document_storage_mongodb`` writes through.
    """
    print("### Retrieve MongoDB")
    return MongoDBAtlasVectorSearch(
        collection,
        OpenAIEmbeddings(disallowed_special=()),
        index_name=MONGODB_INDEX_NAME,
    )
def rag_batch(config):
    """Run the ingestion pipeline: load sources, split them, and index into both stores.

    *config* supplies the splitting parameters (see ``document_splitting``).
    Chroma is written first, then MongoDB.
    """
    chunks = document_splitting(config, document_loading())
    for store in (document_storage_chroma, document_storage_mongodb):
        store(chunks)
def get_llm(config, openai_api_key):
    """Build a ``ChatOpenAI`` model from *config* and the caller's API key.

    *config* must contain ``"model_name"`` and ``"temperature"``.
    """
    settings = {
        "model_name": config["model_name"],
        "openai_api_key": openai_api_key,
        "temperature": config["temperature"],
    }
    return ChatOpenAI(**settings)
def llm_chain(config, openai_api_key, prompt):
    """Answer *prompt* with the chat model alone, without retrieval.

    *prompt* is the user's question. It fills the ``question`` variable of
    ``LLM_CHAIN_PROMPT``.

    Returns:
        ``(completion, chain)``: the result of ``LLMChain.generate`` for the
        single input, and the chain that produced it.
    """
    chain = LLMChain(
        llm=get_llm(config, openai_api_key),
        prompt=LLM_CHAIN_PROMPT,
        verbose=False,
    )
    inputs = [{"question": prompt}]
    completion = chain.generate(inputs)
    return completion, chain
def rag_chain(config, openai_api_key, rag_option, prompt):
    """Answer *prompt* with retrieval-augmented generation.

    Args:
        config: Mapping with ``"model_name"`` and ``"temperature"`` (read by
            ``get_llm``) and ``"k"`` (number of documents the retriever returns).
        openai_api_key: OpenAI API key forwarded to the chat model.
        rag_option: Vector-store backend, ``RAG_CHROMA`` or ``RAG_MONGODB``.
        prompt: The user's question, passed to the chain as ``"query"``.

    Returns:
        ``(completion, chain)``: the output of calling the ``RetrievalQA``
        chain (built with ``return_source_documents=True``) and the chain
        itself.

    Raises:
        ValueError: If *rag_option* is not a supported backend.
    """
    # Resolve the backend before doing any work. Previously an unknown option
    # left `db` unbound and failed later with an opaque UnboundLocalError.
    if rag_option == RAG_CHROMA:
        open_store = document_retrieval_chroma
    elif rag_option == RAG_MONGODB:
        open_store = document_retrieval_mongodb
    else:
        raise ValueError(
            f"Unsupported rag_option {rag_option!r}; "
            f"expected {RAG_CHROMA!r} or {RAG_MONGODB!r}"
        )

    llm = get_llm(config, openai_api_key)
    db = open_store()
    chain = RetrievalQA.from_chain_type(
        llm,
        chain_type_kwargs={"prompt": RAG_CHAIN_PROMPT},
        retriever=db.as_retriever(search_kwargs={"k": config["k"]}),
        return_source_documents=True,
        verbose=True,
    )
    completion = chain({"query": prompt})
    print(completion)
    return completion, chain