Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -106,7 +106,7 @@ def document_retrieval_mongodb(llm, prompt):
|
|
106 |
def llm_chain(llm, prompt):
|
107 |
llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
|
108 |
completion = llm_chain.run({"question": prompt})
|
109 |
-
return completion
|
110 |
|
111 |
def rag_chain(llm, prompt, db):
|
112 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
@@ -114,9 +114,9 @@ def rag_chain(llm, prompt, db):
|
|
114 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
115 |
return_source_documents = True)
|
116 |
completion = rag_chain({"query": prompt})
|
117 |
-
return completion
|
118 |
|
119 |
-
def wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_time_ms):
|
120 |
wandb.init(project = "openai-llm-rag")
|
121 |
if (rag_option == "Off" or str(status_msg) != ""):
|
122 |
result = completion
|
@@ -126,7 +126,7 @@ def wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_t
|
|
126 |
document_1 = completion["source_documents"][1]
|
127 |
document_2 = completion["source_documents"][2]
|
128 |
trace = Trace(
|
129 |
-
kind = "
|
130 |
name = "LLMChain" if (rag_option == "Off") else "RetrievalQA",
|
131 |
status_code = "SUCCESS" if (str(status_msg) == "") else "ERROR",
|
132 |
status_message = str(status_msg),
|
@@ -145,7 +145,8 @@ def wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_t
|
|
145 |
"document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
|
146 |
outputs = {"result": result},
|
147 |
start_time_ms = start_time_ms,
|
148 |
-
end_time_ms = end_time_ms
|
|
|
149 |
)
|
150 |
trace.log("test")
|
151 |
wandb.finish()
|
@@ -159,6 +160,7 @@ def invoke(openai_api_key, rag_option, prompt):
|
|
159 |
raise gr.Error("Prompt is required.")
|
160 |
completion = ""
|
161 |
result = ""
|
|
|
162 |
status_msg = ""
|
163 |
try:
|
164 |
start_time_ms = round(time.time() * 1000)
|
@@ -169,23 +171,23 @@ def invoke(openai_api_key, rag_option, prompt):
|
|
169 |
#splits = document_loading_splitting()
|
170 |
#document_storage_chroma(splits)
|
171 |
db = document_retrieval_chroma(llm, prompt)
|
172 |
-
completion = rag_chain(llm, prompt, db)
|
173 |
result = completion["result"]
|
174 |
elif (rag_option == "MongoDB"):
|
175 |
#splits = document_loading_splitting()
|
176 |
#document_storage_mongodb(splits)
|
177 |
db = document_retrieval_mongodb(llm, prompt)
|
178 |
-
completion = rag_chain(llm, prompt, db)
|
179 |
result = completion["result"]
|
180 |
else:
|
181 |
-
result = llm_chain(llm, prompt)
|
182 |
completion = result
|
183 |
except Exception as e:
|
184 |
status_msg = e
|
185 |
raise gr.Error(e)
|
186 |
finally:
|
187 |
end_time_ms = round(time.time() * 1000)
|
188 |
-
wandb_trace(rag_option, prompt, completion, status_msg, start_time_ms, end_time_ms)
|
189 |
return result
|
190 |
|
191 |
gr.close_all()
|
|
|
106 |
def llm_chain(llm, prompt):
|
107 |
llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
|
108 |
completion = llm_chain.run({"question": prompt})
|
109 |
+
return completion, llm_chain
|
110 |
|
111 |
def rag_chain(llm, prompt, db):
|
112 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
|
|
114 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
115 |
return_source_documents = True)
|
116 |
completion = rag_chain({"query": prompt})
|
117 |
+
return completion, rag_chain
|
118 |
|
119 |
+
def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
|
120 |
wandb.init(project = "openai-llm-rag")
|
121 |
if (rag_option == "Off" or str(status_msg) != ""):
|
122 |
result = completion
|
|
|
126 |
document_1 = completion["source_documents"][1]
|
127 |
document_2 = completion["source_documents"][2]
|
128 |
trace = Trace(
|
129 |
+
kind = "chain",
|
130 |
name = "LLMChain" if (rag_option == "Off") else "RetrievalQA",
|
131 |
status_code = "SUCCESS" if (str(status_msg) == "") else "ERROR",
|
132 |
status_message = str(status_msg),
|
|
|
145 |
"document_2": "" if (rag_option == "Off" or str(status_msg) != "") else str(document_2)},
|
146 |
outputs = {"result": result},
|
147 |
start_time_ms = start_time_ms,
|
148 |
+
end_time_ms = end_time_ms,
|
149 |
+
model_dict={"chain": chain}
|
150 |
)
|
151 |
trace.log("test")
|
152 |
wandb.finish()
|
|
|
160 |
raise gr.Error("Prompt is required.")
|
161 |
completion = ""
|
162 |
result = ""
|
163 |
+
chain = ""
|
164 |
status_msg = ""
|
165 |
try:
|
166 |
start_time_ms = round(time.time() * 1000)
|
|
|
171 |
#splits = document_loading_splitting()
|
172 |
#document_storage_chroma(splits)
|
173 |
db = document_retrieval_chroma(llm, prompt)
|
174 |
+
completion, chain = rag_chain(llm, prompt, db)
|
175 |
result = completion["result"]
|
176 |
elif (rag_option == "MongoDB"):
|
177 |
#splits = document_loading_splitting()
|
178 |
#document_storage_mongodb(splits)
|
179 |
db = document_retrieval_mongodb(llm, prompt)
|
180 |
+
completion, chain = rag_chain(llm, prompt, db)
|
181 |
result = completion["result"]
|
182 |
else:
|
183 |
+
result, chain = llm_chain(llm, prompt)
|
184 |
completion = result
|
185 |
except Exception as e:
|
186 |
status_msg = e
|
187 |
raise gr.Error(e)
|
188 |
finally:
|
189 |
end_time_ms = round(time.time() * 1000)
|
190 |
+
wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms)
|
191 |
return result
|
192 |
|
193 |
gr.close_all()
|