Update app.py
app.py CHANGED
@@ -119,7 +119,7 @@ def rag_chain(llm, prompt, db):
     completion = rag_chain({"query": prompt})
     return completion, rag_chain
 
-def wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms):
+def wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms):
     wandb.init(project = "openai-llm-rag")
     trace = Trace(
         kind = "chain",
@@ -134,14 +134,15 @@ def wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_ti
                   "prompt": prompt if (str(err_msg) == "") else "",
                  },
         outputs = {"result": result if (str(err_msg) == "") else "",
+                   "llm_output": llm_output if (str(err_msg) == "") else "",
                    "completion": str(completion) if (str(err_msg) == "") else "",
                   },
         model_dict = {"llm_client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
                                      str(chain.combine_documents_chain.llm_chain.llm.client)) if (str(err_msg) == "") else "",
                       "llm_model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
                                          str(chain.combine_documents_chain.llm_chain.llm.model_name)) if (str(err_msg) == "") else "",
-                      "llm_temperature": (chain.llm.temperature if (rag_option == RAG_OFF) else
-                                          chain.combine_documents_chain.llm_chain.llm.temperature) if (str(err_msg) == "") else "",
+                      "llm_temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
+                                          str(chain.combine_documents_chain.llm_chain.llm.temperature)) if (str(err_msg) == "") else "",
                       "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
                                        str(chain.combine_documents_chain.llm_chain.prompt)) if (str(err_msg) == "") else "",
                       "chain_retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)) if (str(err_msg) == "") else "",
@@ -162,6 +163,7 @@ def invoke(openai_api_key, rag_option, prompt):
     chain = None
     completion = ""
     result = ""
+    llm_output = ""
     err_msg = ""
     try:
         start_time_ms = round(time.time() * 1000)
@@ -182,14 +184,15 @@ def invoke(openai_api_key, rag_option, prompt):
             result = completion["result"]
         else:
             completion, chain = llm_chain(llm, prompt)
-
-
+            if (completion.generations[0] != None and completion.generations[0][0] != None):
+                result = completion.generations[0][0].text
+                llm_output = completion.generations[0][0].llm_output
     except Exception as e:
         err_msg = e
         raise gr.Error(e)
     finally:
         end_time_ms = round(time.time() * 1000)
-        wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms)
+        wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms)
     return result
 
 gr.close_all()
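
For reference, the new `completion.generations[0][0]` access in `invoke` follows LangChain's `LLMResult` layout, where `generations` is a list of lists with one inner list per input prompt. Below is a minimal sketch of that layout, assuming langchain's schema types (illustration only, not part of the commit); note that in that schema the provider metadata `llm_output` lives on the `LLMResult` itself rather than on the individual `Generation` objects.

# Minimal sketch of langchain's LLMResult layout (illustration, not part of app.py).
from langchain.schema import Generation, LLMResult

completion = LLMResult(
    # one inner list of Generation objects per input prompt
    generations = [[Generation(text = "Hello!", generation_info = {"finish_reason": "stop"})]],
    # provider metadata (token usage, model name) sits on the result itself
    llm_output = {"token_usage": {"total_tokens": 9}, "model_name": "gpt-3.5-turbo"},
)

result = completion.generations[0][0].text  # "Hello!"
usage = completion.llm_output               # {"token_usage": {...}, "model_name": "gpt-3.5-turbo"}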
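
Similarly, the `Trace` built in `wandb_trace` comes from W&B's trace-tree tooling: one trace is constructed per call, timed in milliseconds, and logged to the active run. A minimal, self-contained sketch of that pattern follows; the span name, inputs, and outputs here are placeholders, not values from app.py.

# Minimal sketch of the W&B trace-logging pattern (placeholder values, not from app.py).
import time
import wandb
from wandb.sdk.data_types.trace_tree import Trace

wandb.init(project = "openai-llm-rag")

start_time_ms = round(time.time() * 1000)
# ... invoke the LLM or RAG chain here ...
end_time_ms = round(time.time() * 1000)

trace = Trace(
    name = "llm-chain",           # placeholder span name
    kind = "chain",
    status_code = "success",
    start_time_ms = start_time_ms,
    end_time_ms = end_time_ms,
    inputs = {"prompt": "What is RAG?"},
    outputs = {"result": "..."},
)
trace.log(name = "trace")         # writes the trace tree to the active run
wandb.finish()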