bstraehle committed on
Commit
9854dd8
·
1 Parent(s): 3cfd6d3

Delete rag.py

Browse files
Files changed (1) hide show
  1. rag.py +0 -142
rag.py DELETED
@@ -1,142 +0,0 @@
1
- import logging, os, sys
2
-
3
- from langchain.callbacks import get_openai_callback
4
- from langchain.chains import LLMChain, RetrievalQA
5
- from langchain.chat_models import ChatOpenAI
6
- from langchain.document_loaders import PyPDFLoader, WebBaseLoader
7
- from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
8
- from langchain.document_loaders.generic import GenericLoader
9
- from langchain.document_loaders.parsers import OpenAIWhisperParser
10
- from langchain.embeddings.openai import OpenAIEmbeddings
11
- from langchain.prompts import PromptTemplate
12
- from langchain.text_splitter import RecursiveCharacterTextSplitter
13
- from langchain.vectorstores import Chroma
14
- from langchain.vectorstores import MongoDBAtlasVectorSearch
15
-
16
- from pymongo import MongoClient
17
-
18
# Identifiers for the vector-store back ends that rag_chain() can query.
RAG_CHROMA = "Chroma"
RAG_MONGODB = "MongoDB"

# Source material that load_documents() ingests.
PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
WEB_URL = "https://openai.com/research/gpt-4"
YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"

# Local directories: one handed to the YouTube audio loader, one used as the
# Chroma persistence directory.
YOUTUBE_DIR = "/data/yt"
CHROMA_DIR = "/data/db"

# MongoDB Atlas settings. The cluster URI is read from the environment at
# import time, so a missing variable raises KeyError immediately.
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
MONGODB_DB_NAME = "langchain_db"
MONGODB_COLLECTION_NAME = "gpt-4"
MONGODB_INDEX_NAME = "default"

# Prompt templates; their text is also supplied through environment variables.
LLM_CHAIN_PROMPT = PromptTemplate(
    input_variables=["question"],
    template=os.environ["LLM_TEMPLATE"],
)
RAG_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=os.environ["RAG_TEMPLATE"],
)
40
-
41
# Send INFO-and-above log records to stdout. basicConfig() already attaches a
# stdout StreamHandler to the root logger; adding a second one by hand would
# make every record appear twice, so no extra handler is registered.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
43
-
44
def load_documents():
    """Load every configured source and return the documents as one list.

    Sources are read in a fixed order -- the PDF at ``PDF_URL``, the web page
    at ``WEB_URL``, then the two YouTube videos -- and the returned list keeps
    that order.
    """
    pdf_docs = PyPDFLoader(PDF_URL).load()

    web_docs = WebBaseLoader(WEB_URL).load()

    # The YouTube audio loader is given YOUTUBE_DIR as its directory argument
    # and is paired with the OpenAI Whisper parser.
    youtube_audio = YoutubeAudioLoader(
        [YOUTUBE_URL_1, YOUTUBE_URL_2],
        YOUTUBE_DIR,
    )
    youtube_docs = GenericLoader(youtube_audio, OpenAIWhisperParser()).load()

    return [*pdf_docs, *web_docs, *youtube_docs]
67
-
68
def split_documents(config, docs):
    """Split ``docs`` into chunks with a default RecursiveCharacterTextSplitter.

    ``config`` is accepted for interface compatibility but is not read here.
    """
    return RecursiveCharacterTextSplitter().split_documents(docs)
72
-
73
def store_chroma(chunks):
    """Embed ``chunks`` with OpenAI embeddings and store them in Chroma.

    The store is created with ``CHROMA_DIR`` as its persistence directory.
    Nothing is returned.
    """
    embeddings = OpenAIEmbeddings(disallowed_special=())
    Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        persist_directory=CHROMA_DIR,
    )
78
-
79
def store_mongodb(chunks):
    """Embed ``chunks`` with OpenAI embeddings and store them in MongoDB Atlas.

    Documents go into the collection named by ``MONGODB_DB_NAME`` and
    ``MONGODB_COLLECTION_NAME`` and are associated with the search index
    ``MONGODB_INDEX_NAME``. Nothing is returned.
    """
    client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
    try:
        collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]

        MongoDBAtlasVectorSearch.from_documents(
            documents=chunks,
            embedding=OpenAIEmbeddings(disallowed_special=()),
            collection=collection,
            index_name=MONGODB_INDEX_NAME,
        )
    finally:
        # The client is local to this call and the resulting vector store is
        # discarded, so release the connection pool instead of leaking it.
        client.close()
88
-
89
def rag_ingestion(config):
    """Run the ingestion pipeline: load, split, then store in both back ends.

    ``config`` is forwarded to ``split_documents``. The resulting chunks are
    written to Chroma first and MongoDB second.
    """
    chunks = split_documents(config, load_documents())

    for store in (store_chroma, store_mongodb):
        store(chunks)
96
-
97
def retrieve_chroma():
    """Return a Chroma vector store opened on ``CHROMA_DIR``.

    The store uses OpenAI embeddings as its embedding function.
    """
    embeddings = OpenAIEmbeddings(disallowed_special=())
    return Chroma(
        embedding_function=embeddings,
        persist_directory=CHROMA_DIR,
    )
101
-
102
def retrieve_mongodb():
    """Return a MongoDB Atlas vector store built from the cluster URI.

    The namespace is ``"<MONGODB_DB_NAME>.<MONGODB_COLLECTION_NAME>"``, and the
    store uses OpenAI embeddings and the index ``MONGODB_INDEX_NAME``.
    """
    namespace = MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME
    embeddings = OpenAIEmbeddings(disallowed_special=())
    return MongoDBAtlasVectorSearch.from_connection_string(
        MONGODB_ATLAS_CLUSTER_URI,
        namespace,
        embeddings,
        index_name=MONGODB_INDEX_NAME,
    )
108
-
109
def get_llm(config):
    """Build a ChatOpenAI model from ``config``.

    Reads ``config["model_name"]`` and ``config["temperature"]``; a missing
    key raises ``KeyError``.
    """
    model_name = config["model_name"]
    temperature = config["temperature"]
    return ChatOpenAI(model_name=model_name, temperature=temperature)
113
-
114
def llm_chain(config, prompt):
    """Answer ``prompt`` with the LLM alone, without retrieval.

    Builds an LLMChain from ``get_llm(config)`` and ``LLM_CHAIN_PROMPT`` and
    runs it with ``prompt`` as the ``question`` variable, inside an OpenAI
    callback context.

    Returns a tuple ``(completion, chain, cb)``: the generation result, the
    chain that produced it, and the OpenAI callback handler.
    """
    llm = get_llm(config)
    chain = LLMChain(llm=llm, prompt=LLM_CHAIN_PROMPT)

    with get_openai_callback() as cb:
        completion = chain.generate([{"question": prompt}])

    return completion, chain, cb
123
-
124
def rag_chain(config, rag_option, prompt):
    """Answer ``prompt`` with retrieval-augmented generation.

    Args:
        config: Mapping read for ``"model_name"`` and ``"temperature"`` (via
            ``get_llm``) and for ``"k"``, which is passed to the retriever as
            a search keyword argument.
        rag_option: Which vector store to query; must equal ``RAG_CHROMA`` or
            ``RAG_MONGODB``.
        prompt: The user question, passed to the chain as ``"query"``.

    Returns:
        A tuple ``(completion, chain, cb)``: the chain output (requested with
        source documents), the RetrievalQA chain itself, and the OpenAI
        callback handler.

    Raises:
        ValueError: If ``rag_option`` is not one of the supported values.
    """
    llm = get_llm(config)

    if rag_option == RAG_CHROMA:
        db = retrieve_chroma()
    elif rag_option == RAG_MONGODB:
        db = retrieve_mongodb()
    else:
        # Without this branch ``db`` would be unbound and the failure would
        # surface later as an obscure UnboundLocalError.
        raise ValueError(
            f"Unsupported rag_option {rag_option!r}; "
            f"expected {RAG_CHROMA!r} or {RAG_MONGODB!r}"
        )

    chain = RetrievalQA.from_chain_type(
        llm,
        chain_type_kwargs={"prompt": RAG_CHAIN_PROMPT, "verbose": True},
        retriever=db.as_retriever(search_kwargs={"k": config["k"]}),
        return_source_documents=True,
    )

    with get_openai_callback() as cb:
        completion = chain({"query": prompt})

    return completion, chain, cb