bstraehle commited on
Commit
ede11c3
·
1 Parent(s): 8546e9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -119,7 +119,7 @@ def rag_chain(llm, prompt, db):
119
  completion = rag_chain({"query": prompt})
120
  return completion, rag_chain
121
 
122
- def wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms):
123
  wandb.init(project = "openai-llm-rag")
124
  trace = Trace(
125
  kind = "chain",
@@ -134,14 +134,15 @@ def wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_ti
134
  "prompt": prompt if (str(err_msg) == "") else "",
135
  },
136
  outputs = {"result": result if (str(err_msg) == "") else "",
 
137
  "completion": str(completion) if (str(err_msg) == "") else "",
138
  },
139
  model_dict = {"llm_client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
140
  str(chain.combine_documents_chain.llm_chain.llm.client)) if (str(err_msg) == "") else "",
141
  "llm_model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
142
  str(chain.combine_documents_chain.llm_chain.llm.model_name)) if (str(err_msg) == "") else "",
143
- "llm_temperature": (chain.llm.temperature if (rag_option == RAG_OFF) else
144
- chain.combine_documents_chain.llm_chain.llm.temperature) if (str(err_msg) == "") else "",
145
  "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
146
  str(chain.combine_documents_chain.llm_chain.prompt)) if (str(err_msg) == "") else "",
147
  "chain_retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)) if (str(err_msg) == "") else "",
@@ -162,6 +163,7 @@ def invoke(openai_api_key, rag_option, prompt):
162
  chain = None
163
  completion = ""
164
  result = ""
 
165
  err_msg = ""
166
  try:
167
  start_time_ms = round(time.time() * 1000)
@@ -182,14 +184,15 @@ def invoke(openai_api_key, rag_option, prompt):
182
  result = completion["result"]
183
  else:
184
  completion, chain = llm_chain(llm, prompt)
185
- result = completion.generations[0][0].text if (completion.generations[0] != None and
186
- completion.generations[0][0] != None) else ""
 
187
  except Exception as e:
188
  err_msg = e
189
  raise gr.Error(e)
190
  finally:
191
  end_time_ms = round(time.time() * 1000)
192
- wandb_trace(rag_option, prompt, completion, result, chain, err_msg, start_time_ms, end_time_ms)
193
  return result
194
 
195
  gr.close_all()
 
119
  completion = rag_chain({"query": prompt})
120
  return completion, rag_chain
121
 
122
+ def wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms):
123
  wandb.init(project = "openai-llm-rag")
124
  trace = Trace(
125
  kind = "chain",
 
134
  "prompt": prompt if (str(err_msg) == "") else "",
135
  },
136
  outputs = {"result": result if (str(err_msg) == "") else "",
137
+ "llm_output": llm_output if (str(err_msg) == "") else "",
138
  "completion": str(completion) if (str(err_msg) == "") else "",
139
  },
140
  model_dict = {"llm_client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
141
  str(chain.combine_documents_chain.llm_chain.llm.client)) if (str(err_msg) == "") else "",
142
  "llm_model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
143
  str(chain.combine_documents_chain.llm_chain.llm.model_name)) if (str(err_msg) == "") else "",
144
+ "llm_temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
145
+ str(chain.combine_documents_chain.llm_chain.llm.temperature)) if (str(err_msg) == "") else "",
146
  "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
147
  str(chain.combine_documents_chain.llm_chain.prompt)) if (str(err_msg) == "") else "",
148
  "chain_retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)) if (str(err_msg) == "") else "",
 
163
  chain = None
164
  completion = ""
165
  result = ""
166
+ llm_output = ""
167
  err_msg = ""
168
  try:
169
  start_time_ms = round(time.time() * 1000)
 
184
  result = completion["result"]
185
  else:
186
  completion, chain = llm_chain(llm, prompt)
187
+ if (completion.generations[0] != None and completion.generations[0][0] != None):
188
+ result = completion.generations[0][0].text
189
+ llm_output = completion.generations[0][0].llm_output
190
  except Exception as e:
191
  err_msg = e
192
  raise gr.Error(e)
193
  finally:
194
  end_time_ms = round(time.time() * 1000)
195
+ wandb_trace(rag_option, prompt, completion, result, llm_output, chain, err_msg, start_time_ms, end_time_ms)
196
  return result
197
 
198
  gr.close_all()