bstraehle committed on
Commit
4896967
·
1 Parent(s): f166d62

Update rag_langchain.py

Browse files
Files changed (1) hide show
  1. rag_langchain.py +108 -105
rag_langchain.py CHANGED
@@ -15,121 +15,124 @@ from langchain.vectorstores import MongoDBAtlasVectorSearch
15
 
16
  from pymongo import MongoClient
17
 
18
- PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
19
- WEB_URL = "https://openai.com/research/gpt-4"
20
- YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
21
- YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
22
-
23
- CHROMA_DIR = "/data/db"
24
- YOUTUBE_DIR = "/data/yt"
25
-
26
- MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
27
- MONGODB_DB_NAME = "langchain_db"
28
- MONGODB_COLLECTION_NAME = "gpt-4"
29
- MONGODB_INDEX_NAME = "default"
30
-
31
- LLM_CHAIN_PROMPT = PromptTemplate(
32
- input_variables = ["question"],
33
- template = os.environ["LLM_TEMPLATE"])
34
- RAG_CHAIN_PROMPT = PromptTemplate(
35
- input_variables = ["context", "question"],
36
- template = os.environ["RAG_TEMPLATE"])
37
-
38
- logging.basicConfig(stream = sys.stdout, level = logging.INFO)
39
- logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
40
-
41
- def load_documents():
42
- docs = []
 
 
 
43
 
44
- # PDF
45
- loader = PyPDFLoader(PDF_URL)
46
- docs.extend(loader.load())
47
- #print("docs = " + str(len(docs)))
48
 
49
- # Web
50
- loader = WebBaseLoader(WEB_URL)
51
- docs.extend(loader.load())
52
- #print("docs = " + str(len(docs)))
53
 
54
- # YouTube
55
- loader = GenericLoader(
56
- YoutubeAudioLoader(
57
- [YOUTUBE_URL_1, YOUTUBE_URL_2],
58
- YOUTUBE_DIR),
59
- OpenAIWhisperParser())
60
- docs.extend(loader.load())
61
- #print("docs = " + str(len(docs)))
62
 
63
- return docs
64
 
65
- def split_documents(config, docs):
66
- text_splitter = RecursiveCharacterTextSplitter()
67
 
68
- return text_splitter.split_documents(docs)
69
 
70
- def store_documents_chroma(chunks):
71
- Chroma.from_documents(
72
- documents = chunks,
73
- embedding = OpenAIEmbeddings(disallowed_special = ()),
74
- persist_directory = CHROMA_DIR)
75
-
76
- def store_documents_mongodb(chunks):
77
- client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
78
- collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
79
-
80
- MongoDBAtlasVectorSearch.from_documents(
81
- documents = chunks,
82
- embedding = OpenAIEmbeddings(disallowed_special = ()),
83
- collection = collection,
84
- index_name = MONGODB_INDEX_NAME)
85
-
86
- def rag_ingestion_langchain(config):
87
- docs = load_documents()
88
 
89
- chunks = split_documents(config, docs)
90
 
91
- #store_documents_chroma(chunks)
92
- store_documents_mongodb(chunks)
93
-
94
- def get_vector_store_chroma():
95
- return Chroma(
96
- embedding_function = OpenAIEmbeddings(disallowed_special = ()),
97
- persist_directory = CHROMA_DIR)
98
-
99
- def get_vector_store_mongodb():
100
- return MongoDBAtlasVectorSearch.from_connection_string(
101
- MONGODB_ATLAS_CLUSTER_URI,
102
- MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
103
- OpenAIEmbeddings(disallowed_special = ()),
104
- index_name = MONGODB_INDEX_NAME)
105
-
106
- def get_llm(config):
107
- return ChatOpenAI(
108
- model_name = config["model_name"],
109
- temperature = config["temperature"])
110
-
111
- def llm_chain(config, prompt):
112
- llm_chain = LLMChain(
113
- llm = get_llm(config),
114
- prompt = LLM_CHAIN_PROMPT)
115
 
116
- with get_openai_callback() as callback:
117
- completion = llm_chain.generate([{"question": prompt}])
118
 
119
- return completion, llm_chain, callback
120
-
121
- def rag_chain(config, prompt):
122
- #vector_store = get_vector_store_chroma()
123
- vector_store = get_vector_store_mongodb()
124
-
125
- rag_chain = RetrievalQA.from_chain_type(
126
- get_llm(config),
127
- chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
128
- "verbose": True},
129
- retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
130
- return_source_documents = True)
131
 
132
- with get_openai_callback() as callback:
133
- completion = rag_chain({"query": prompt})
134
 
135
- return completion, rag_chain, callback
 
15
 
16
  from pymongo import MongoClient
17
 
18
+ #PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
19
+ #WEB_URL = "https://openai.com/research/gpt-4"
20
+ #YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
21
+ #YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
22
+
23
+ #MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
24
+ #MONGODB_DB_NAME = "langchain_db"
25
+ #MONGODB_COLLECTION_NAME = "gpt-4"
26
+ #MONGODB_INDEX_NAME = "default"
27
+
28
+ #logging.basicConfig(stream = sys.stdout, level = logging.INFO)
29
+ #logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
30
+
31
class LangChainRAG:
    """RAG pipeline built on LangChain.

    Ingests documents from PDF, web, and YouTube sources, stores their
    embeddings in either a local Chroma store or a MongoDB Atlas vector
    collection, and answers prompts via a plain LLM chain or a
    retrieval-augmented QA chain.
    """

    # Source documents to ingest (restored from the former module-level
    # constants: the refactor commented them out but the methods below
    # still reference them via `self.*`, which raised AttributeError).
    PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
    WEB_URL = "https://openai.com/research/gpt-4"
    YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
    YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"

    # Local storage locations.
    CHROMA_DIR = "/data/db"
    YOUTUBE_DIR = "/data/yt"

    # MongoDB Atlas vector-store settings.  The URI is read from the
    # environment at class-creation time, matching the previous
    # module-level behavior.
    MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
    MONGODB_DB_NAME = "langchain_db"
    MONGODB_COLLECTION_NAME = "gpt-4"
    MONGODB_INDEX_NAME = "default"

    # Prompt templates come from the environment so they can be tuned
    # without a code change.
    LLM_CHAIN_PROMPT = PromptTemplate(
        input_variables = ["question"],
        template = os.environ["LLM_TEMPLATE"])
    RAG_CHAIN_PROMPT = PromptTemplate(
        input_variables = ["context", "question"],
        template = os.environ["RAG_TEMPLATE"])

    def load_documents(self):
        """Load documents from the PDF, web, and YouTube sources.

        Returns:
            list: LangChain ``Document`` objects from all three loaders,
            in load order (PDF, then web, then YouTube transcripts).
        """
        docs = []

        # PDF
        loader = PyPDFLoader(self.PDF_URL)
        docs.extend(loader.load())

        # Web
        loader = WebBaseLoader(self.WEB_URL)
        docs.extend(loader.load())

        # YouTube: download the audio tracks, then transcribe them
        # with OpenAI Whisper.
        loader = GenericLoader(
            YoutubeAudioLoader(
                [self.YOUTUBE_URL_1, self.YOUTUBE_URL_2],
                self.YOUTUBE_DIR),
            OpenAIWhisperParser())
        docs.extend(loader.load())

        return docs

    def split_documents(self, config, docs):
        """Split documents into chunks for embedding.

        Args:
            config: run configuration; currently unused here but kept for
                interface stability (chunking parameters may come from it).
            docs: list of LangChain ``Document`` objects.

        Returns:
            list: chunked ``Document`` objects.
        """
        text_splitter = RecursiveCharacterTextSplitter()
        return text_splitter.split_documents(docs)

    def store_documents_chroma(self, chunks):
        """Embed ``chunks`` and persist them in the local Chroma store."""
        Chroma.from_documents(
            documents = chunks,
            embedding = OpenAIEmbeddings(disallowed_special = ()),
            persist_directory = self.CHROMA_DIR)

    def store_documents_mongodb(self, chunks):
        """Embed ``chunks`` and store them in the MongoDB Atlas collection."""
        client = MongoClient(self.MONGODB_ATLAS_CLUSTER_URI)
        collection = client[self.MONGODB_DB_NAME][self.MONGODB_COLLECTION_NAME]

        MongoDBAtlasVectorSearch.from_documents(
            documents = chunks,
            embedding = OpenAIEmbeddings(disallowed_special = ()),
            collection = collection,
            index_name = self.MONGODB_INDEX_NAME)

    def rag_ingestion_langchain(self, config):
        """Full ingestion pipeline: load, split, and store documents."""
        docs = self.load_documents()

        chunks = self.split_documents(config, docs)

        #self.store_documents_chroma(chunks)
        self.store_documents_mongodb(chunks)

    def get_vector_store_chroma(self):
        """Open the persisted local Chroma vector store."""
        return Chroma(
            embedding_function = OpenAIEmbeddings(disallowed_special = ()),
            persist_directory = self.CHROMA_DIR)

    def get_vector_store_mongodb(self):
        """Connect to the MongoDB Atlas vector store."""
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.MONGODB_ATLAS_CLUSTER_URI,
            self.MONGODB_DB_NAME + "." + self.MONGODB_COLLECTION_NAME,
            OpenAIEmbeddings(disallowed_special = ()),
            index_name = self.MONGODB_INDEX_NAME)

    def get_llm(self, config):
        """Build a ``ChatOpenAI`` LLM from ``config["model_name"]`` and
        ``config["temperature"]``."""
        return ChatOpenAI(
            model_name = config["model_name"],
            temperature = config["temperature"])

    def llm_chain(self, config, prompt):
        """Run the plain (non-retrieval) LLM chain on ``prompt``.

        Returns:
            tuple: (completion, chain, OpenAI token-usage callback).
        """
        llm_chain = LLMChain(
            llm = self.get_llm(config),
            prompt = self.LLM_CHAIN_PROMPT)

        # Track token usage / cost for the call.
        with get_openai_callback() as callback:
            completion = llm_chain.generate([{"question": prompt}])

        return completion, llm_chain, callback

    def rag_chain(self, config, prompt):
        """Run retrieval-augmented QA on ``prompt`` against the vector store.

        Retrieves the top ``config["k"]`` chunks and returns the answer
        with its source documents.

        Returns:
            tuple: (completion, chain, OpenAI token-usage callback).
        """
        #vector_store = self.get_vector_store_chroma()
        vector_store = self.get_vector_store_mongodb()

        rag_chain = RetrievalQA.from_chain_type(
            self.get_llm(config),
            chain_type_kwargs = {"prompt": self.RAG_CHAIN_PROMPT,
                                 "verbose": True},
            retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
            return_source_documents = True)

        # Track token usage / cost for the call.
        with get_openai_callback() as callback:
            completion = rag_chain({"query": prompt})

        return completion, rag_chain, callback