bstraehle commited on
Commit
08cc2d6
·
1 Parent(s): 99e5427

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +5 -5
rag.py CHANGED
@@ -81,13 +81,13 @@ def document_retrieval_mongodb(llm, prompt):
81
  OpenAIEmbeddings(disallowed_special = ()),
82
  index_name = MONGODB_INDEX_NAME)
83
 
84
- def get_llm(openai_api_key):
85
  return ChatOpenAI(model_name = config["model_name"],
86
  openai_api_key = openai_api_key,
87
  temperature = config["temperature"])
88
 
89
- def llm_chain(openai_api_key, prompt):
90
- llm_chain = LLMChain(llm = get_llm(openai_api_key),
91
  prompt = LLM_CHAIN_PROMPT,
92
  verbose = False)
93
 
@@ -95,8 +95,8 @@ def llm_chain(openai_api_key, prompt):
95
 
96
  return completion, llm_chain
97
 
98
- def rag_chain(openai_api_key, prompt):
99
- llm = get_llm(openai_api_key)
100
 
101
  db = document_retrieval_chroma(llm, prompt)
102
 
 
81
  OpenAIEmbeddings(disallowed_special = ()),
82
  index_name = MONGODB_INDEX_NAME)
83
 
84
+ def get_llm(config, openai_api_key):
85
  return ChatOpenAI(model_name = config["model_name"],
86
  openai_api_key = openai_api_key,
87
  temperature = config["temperature"])
88
 
89
+ def llm_chain(config, openai_api_key, prompt):
90
+ llm_chain = LLMChain(llm = get_llm(config, openai_api_key),
91
  prompt = LLM_CHAIN_PROMPT,
92
  verbose = False)
93
 
 
95
 
96
  return completion, llm_chain
97
 
98
+ def rag_chain(config, openai_api_key, prompt):
99
+ llm = get_llm(config, openai_api_key)
100
 
101
  db = document_retrieval_chroma(llm, prompt)
102