bstraehle committed on
Commit
03ab966
·
1 Parent(s): 1eb30b8

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +13 -11
rag.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
 
3
  from langchain.chains import LLMChain, RetrievalQA
4
  from langchain.chat_models import ChatOpenAI
@@ -38,6 +38,8 @@ config = {
38
  "chunk_overlap": 150,
39
  "chunk_size": 1500,
40
  "k": 3,
 
 
41
  }
42
 
43
  def document_loading_splitting():
@@ -87,23 +89,23 @@ def document_retrieval_mongodb(llm, prompt):
87
  OpenAIEmbeddings(disallowed_special = ()),
88
  index_name = MONGODB_INDEX_NAME)
89
 
90
- def llm_chain(llm, prompt):
91
- llm_chain = LLMChain(llm = llm,
 
 
 
 
 
92
  prompt = LLM_CHAIN_PROMPT,
93
  verbose = False)
94
  completion = llm_chain.generate([{"question": prompt}])
95
  return completion, llm_chain
96
 
97
- def rag_chain(llm, prompt, db):
98
- rag_chain = RetrievalQA.from_chain_type(llm,
99
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
100
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
101
  return_source_documents = True,
102
  verbose = False)
103
  completion = rag_chain({"query": prompt})
104
- return completion, rag_chain
105
-
106
- def get_llm():
107
- return ChatOpenAI(model_name = config["model_name"],
108
- openai_api_key = openai_api_key,
109
- temperature = config["temperature"])
 
1
+ import openai, os
2
 
3
  from langchain.chains import LLMChain, RetrievalQA
4
  from langchain.chat_models import ChatOpenAI
 
38
  "chunk_overlap": 150,
39
  "chunk_size": 1500,
40
  "k": 3,
41
+ "model_name": "gpt-4-0613",
42
+ "temperature": 0,
43
  }
44
 
45
  def document_loading_splitting():
 
89
  OpenAIEmbeddings(disallowed_special = ()),
90
  index_name = MONGODB_INDEX_NAME)
91
 
92
def get_llm():
    """Construct the ChatOpenAI model from the module-level `config` settings.

    Returns a fresh ChatOpenAI instance using the configured model name and
    temperature, authenticated with the module-level `openai_api_key`.
    """
    settings = {
        "model_name": config["model_name"],
        "openai_api_key": openai_api_key,
        "temperature": config["temperature"],
    }
    return ChatOpenAI(**settings)
96
+
97
def llm_chain(prompt):
    """Run `prompt` through a plain (non-retrieval) LLMChain.

    Builds the chain with the shared model from get_llm() and the
    module-level LLM_CHAIN_PROMPT template, then generates a completion
    for the given question.

    Returns:
        (completion, chain): the generation result and the chain object.
    """
    # Local renamed from the original's `llm_chain` to avoid shadowing
    # the enclosing function name.
    chain = LLMChain(llm = get_llm(),
                     prompt = LLM_CHAIN_PROMPT,
                     verbose = False)
    completion = chain.generate([{"question": prompt}])
    return completion, chain
103
 
104
def rag_chain(prompt, db):
    """Answer `prompt` with a RetrievalQA chain backed by vector store `db`.

    The retriever pulls the top config["k"] documents from `db`; the chain
    uses the shared model from get_llm() and the module-level
    RAG_CHAIN_PROMPT template, and also returns its source documents.

    Returns:
        (completion, chain): the query result and the chain object.
    """
    # Hoisted for readability; behavior identical to passing it inline.
    retriever = db.as_retriever(search_kwargs = {"k": config["k"]})
    chain = RetrievalQA.from_chain_type(get_llm(),
                                        chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
                                        retriever = retriever,
                                        return_source_documents = True,
                                        verbose = False)
    completion = chain({"query": prompt})
    return completion, chain