Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -119,7 +119,7 @@ def rag_chain(llm, prompt, db):
|
|
119 |
completion = rag_chain({"query": prompt})
|
120 |
return completion, rag_chain
|
121 |
|
122 |
-
def wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms):
|
123 |
wandb.init(project = "openai-llm-rag")
|
124 |
trace = Trace(
|
125 |
kind = "chain",
|
@@ -134,6 +134,7 @@ def wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_m
|
|
134 |
"prompt": prompt if (str(err_msg) == "") else "",
|
135 |
},
|
136 |
outputs = {"result": result if (str(err_msg) == "") else "",
|
|
|
137 |
"llm_output": llm_output if (str(err_msg) == "") else "",
|
138 |
"completion": str(completion) if (str(err_msg) == "") else "",
|
139 |
},
|
@@ -186,7 +187,8 @@ def invoke(openai_api_key, rag_option, prompt):
|
|
186 |
completion, chain = llm_chain(llm, prompt)
|
187 |
if (completion.generations[0] != None and completion.generations[0][0] != None):
|
188 |
result = completion.generations[0][0].text
|
189 |
-
|
|
|
190 |
print(completion)
|
191 |
print("###")
|
192 |
print(completion.generations[0])
|
@@ -195,7 +197,7 @@ def invoke(openai_api_key, rag_option, prompt):
|
|
195 |
raise gr.Error(e)
|
196 |
finally:
|
197 |
end_time_ms = round(time.time() * 1000)
|
198 |
-
wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms)
|
199 |
return result
|
200 |
|
201 |
gr.close_all()
|
|
|
119 |
completion = rag_chain({"query": prompt})
|
120 |
return completion, rag_chain
|
121 |
|
122 |
+
def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
|
123 |
wandb.init(project = "openai-llm-rag")
|
124 |
trace = Trace(
|
125 |
kind = "chain",
|
|
|
134 |
"prompt": prompt if (str(err_msg) == "") else "",
|
135 |
},
|
136 |
outputs = {"result": result if (str(err_msg) == "") else "",
|
137 |
+
"generation_info": generation_info if (str(err_msg) == "") else "",
|
138 |
"llm_output": llm_output if (str(err_msg) == "") else "",
|
139 |
"completion": str(completion) if (str(err_msg) == "") else "",
|
140 |
},
|
|
|
187 |
completion, chain = llm_chain(llm, prompt)
|
188 |
if (completion.generations[0] != None and completion.generations[0][0] != None):
|
189 |
result = completion.generations[0][0].text
|
190 |
+
generation_info = completion.generations[0][0].generation_info
|
191 |
+
llm_output = completion.generations.llm_output
|
192 |
print(completion)
|
193 |
print("###")
|
194 |
print(completion.generations[0])
|
|
|
197 |
raise gr.Error(e)
|
198 |
finally:
|
199 |
end_time_ms = round(time.time() * 1000)
|
200 |
+
wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
|
201 |
return result
|
202 |
|
203 |
gr.close_all()
|