bstraehle committed
Commit 1e76fb4 · 1 Parent(s): ba1f19c

Update app.py

Files changed (1)
  1. app.py +5 -3
app.py CHANGED
@@ -119,7 +119,7 @@ def rag_chain(llm, prompt, db):
     completion = rag_chain({"query": prompt})
     return completion, rag_chain
 
-def wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms):
+def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
     wandb.init(project = "openai-llm-rag")
     trace = Trace(
         kind = "chain",
@@ -134,6 +134,7 @@ def wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_m
                   "prompt": prompt if (str(err_msg) == "") else "",
                  },
         outputs = {"result": result if (str(err_msg) == "") else "",
+                   "generation_info": generation_info if (str(err_msg) == "") else "",
                    "llm_output": llm_output if (str(err_msg) == "") else "",
                    "completion": str(completion) if (str(err_msg) == "") else "",
                   },
@@ -186,7 +187,8 @@ def invoke(openai_api_key, rag_option, prompt):
             completion, chain = llm_chain(llm, prompt)
             if (completion.generations[0] != None and completion.generations[0][0] != None):
                 result = completion.generations[0][0].text
-                #llm_output = completion.generations[0][0].llm_output
+                generation_info = completion.generations[0][0].generation_info
+                llm_output = completion.generations.llm_output
             print(completion)
             print("###")
             print(completion.generations[0])
@@ -195,7 +197,7 @@ def invoke(openai_api_key, rag_option, prompt):
         raise gr.Error(e)
     finally:
         end_time_ms = round(time.time() * 1000)
-        wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms)
+        wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms)
     return result
 
 gr.close_all()
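
Note on the two new fields: assuming completion here is the LangChain LLMResult returned by LLMChain.generate(), generation_info is read per generation, while llm_output is an attribute of the result object itself; completion.generations is a plain Python list, so an access like completion.generations.llm_output would raise AttributeError. A minimal sketch of the usual access pattern follows; the model, prompt template, and question are placeholders, not taken from app.py.

    # Sketch only: reads generation_info and llm_output from a LangChain LLMResult.
    # Requires OPENAI_API_KEY in the environment; model and prompt are placeholders.
    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import PromptTemplate

    llm = ChatOpenAI(model_name = "gpt-4", temperature = 0)
    chain = LLMChain(llm = llm, prompt = PromptTemplate.from_template("{question}"))

    completion = chain.generate([{"question": "What is retrieval-augmented generation?"}])  # returns an LLMResult

    result = completion.generations[0][0].text                       # generated text
    generation_info = completion.generations[0][0].generation_info   # per-generation metadata, e.g. finish_reason
    llm_output = completion.llm_output                               # provider metadata, e.g. token_usage;
                                                                     # lives on the LLMResult, not on the generations list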