Spaces:
Build error
Build error
Update rag.py
Browse files
rag.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os
|
2 |
|
3 |
from langchain.chains import LLMChain, RetrievalQA
|
4 |
from langchain.chat_models import ChatOpenAI
|
@@ -38,6 +38,8 @@ config = {
|
|
38 |
"chunk_overlap": 150,
|
39 |
"chunk_size": 1500,
|
40 |
"k": 3,
|
|
|
|
|
41 |
}
|
42 |
|
43 |
def document_loading_splitting():
|
@@ -87,23 +89,23 @@ def document_retrieval_mongodb(llm, prompt):
|
|
87 |
OpenAIEmbeddings(disallowed_special = ()),
|
88 |
index_name = MONGODB_INDEX_NAME)
|
89 |
|
90 |
-
def
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
92 |
prompt = LLM_CHAIN_PROMPT,
|
93 |
verbose = False)
|
94 |
completion = llm_chain.generate([{"question": prompt}])
|
95 |
return completion, llm_chain
|
96 |
|
97 |
-
def rag_chain(
|
98 |
-
rag_chain = RetrievalQA.from_chain_type(
|
99 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
100 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
101 |
return_source_documents = True,
|
102 |
verbose = False)
|
103 |
completion = rag_chain({"query": prompt})
|
104 |
-
return completion, rag_chain
|
105 |
-
|
106 |
-
def get_llm():
|
107 |
-
return ChatOpenAI(model_name = config["model_name"],
|
108 |
-
openai_api_key = openai_api_key,
|
109 |
-
temperature = config["temperature"])
|
|
|
1 |
+
import openai, os
|
2 |
|
3 |
from langchain.chains import LLMChain, RetrievalQA
|
4 |
from langchain.chat_models import ChatOpenAI
|
|
|
38 |
"chunk_overlap": 150,
|
39 |
"chunk_size": 1500,
|
40 |
"k": 3,
|
41 |
+
"model_name": "gpt-4-0613",
|
42 |
+
"temperature": 0,
|
43 |
}
|
44 |
|
45 |
def document_loading_splitting():
|
|
|
89 |
OpenAIEmbeddings(disallowed_special = ()),
|
90 |
index_name = MONGODB_INDEX_NAME)
|
91 |
|
92 |
+
def get_llm():
|
93 |
+
return ChatOpenAI(model_name = config["model_name"],
|
94 |
+
openai_api_key = openai_api_key,
|
95 |
+
temperature = config["temperature"])
|
96 |
+
|
97 |
+
def llm_chain(prompt):
|
98 |
+
llm_chain = LLMChain(llm = get_llm(),
|
99 |
prompt = LLM_CHAIN_PROMPT,
|
100 |
verbose = False)
|
101 |
completion = llm_chain.generate([{"question": prompt}])
|
102 |
return completion, llm_chain
|
103 |
|
104 |
+
def rag_chain(prompt, db):
|
105 |
+
rag_chain = RetrievalQA.from_chain_type(get_llm(),
|
106 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
107 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
108 |
return_source_documents = True,
|
109 |
verbose = False)
|
110 |
completion = rag_chain({"query": prompt})
|
111 |
+
return completion, rag_chain
|
|
|
|
|
|
|
|
|
|