bstraehle committed
Commit 12d440a · Parent(s): 166113c

Update app.py

Files changed (1): app.py (+18 -16)
app.py CHANGED
@@ -120,9 +120,11 @@ def rag_chain(llm, prompt, db):
     completion = rag_chain({"query": prompt})
     return completion, rag_chain
 
-def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms):
-    if (rag_option == RAG_OFF or str(status_msg) != ""):
-        result = completion
+def wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms):
+    if (str(err_msg) != ""):
+        result = ""
+    elif (rag_option == RAG_OFF):
+        result = completion.text
     else:
         result = completion["result"]
         docs_meta = str([doc.metadata for doc in completion["source_documents"]])
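Applied, the new branch picks what gets logged as the trace result: nothing on error, the generation's text when RAG is off, and the retrieval chain's result otherwise. A runnable sketch of that selection, with hypothetical completion shapes inferred from the surrounding code:

```python
RAG_OFF = "Off"  # assumption: the app defines RAG_OFF elsewhere

class FakeGeneration:
    """Stand-in for the object llm_chain now returns; only .text is assumed."""
    text = "LLM answer"

def pick_result(rag_option, completion, err_msg):
    if str(err_msg) != "":
        return ""                    # error path: log an empty result
    elif rag_option == RAG_OFF:
        return completion.text       # plain LLM path: generation object
    else:
        return completion["result"]  # RAG path: RetrievalQA-style dict

print(pick_result(RAG_OFF, FakeGeneration(), ""))       # LLM answer
print(pick_result("On", {"result": "RAG answer"}, ""))  # RAG answer
print(pick_result("On", None, ValueError("boom")))      # empty string
```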
@@ -130,8 +132,8 @@ def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms
     trace = Trace(
         kind = "chain",
         name = type(chain).__name__ if (chain != None) else "",
-        status_code = "success" if (str(status_msg) == "") else "error",
-        status_message = str(status_msg),
+        status_code = "success" if (str(err_msg) == "") else "error",
+        status_message = str(err_msg),
         metadata = {
             "chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
             "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
@@ -139,11 +141,12 @@ def wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms
             "model": config["model"],
             "temperature": config["temperature"],
         },
-        inputs = {"rag_option": rag_option if (str(status_msg) == "") else "",
-                  "prompt": str(prompt if (str(status_msg) == "") else ""),
-                  "prompt_template": str((llm_template if (rag_option == RAG_OFF) else rag_template) if (str(status_msg) == "") else ""),
-                  "docs_meta": "" if (rag_option == RAG_OFF or str(status_msg) != "") else docs_meta},
-        outputs = {"result": result},
+        inputs = {"rag_option": rag_option if (str(err_msg) == "") else "",
+                  "prompt": str(prompt if (str(err_msg) == "") else ""),
+                  "prompt_template": str((llm_template if (rag_option == RAG_OFF) else rag_template) if (str(err_msg) == "") else ""),
+                  "docs_meta": "" if (str(err_msg) != "" or rag_option == RAG_OFF) else docs_meta},
+        outputs = {"result": result,
+                   "completion": completion},
         start_time_ms = start_time_ms,
         end_time_ms = end_time_ms
     )
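Together, these two hunks route the renamed err_msg into the W&B trace: status_code and status_message carry the error state, every input is blanked out as soon as an error occurred, and the raw completion is now logged next to result. A minimal sketch of that logging pattern, with sample values, assuming the wandb Trace API (wandb.sdk.data_types.trace_tree) the app already uses:

```python
import time
import wandb
from wandb.sdk.data_types.trace_tree import Trace

wandb.init(project="demo")   # hypothetical project name

err_msg = ""                 # set to an exception to produce an "error" trace
start_time_ms = round(time.time() * 1000)
trace = Trace(
    name = "RetrievalQA",
    kind = "chain",
    status_code = "success" if str(err_msg) == "" else "error",
    status_message = str(err_msg),
    metadata = {"model": "gpt-4", "temperature": 0},  # sample config values
    inputs = {"prompt": "What is RAG?" if str(err_msg) == "" else ""},
    outputs = {"result": "sample answer", "completion": "raw completion object"},
    start_time_ms = start_time_ms,
    end_time_ms = round(time.time() * 1000),
)
trace.log("evaluation")      # name of the trace table in the run
wandb.finish()
```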
@@ -160,7 +163,7 @@ def invoke(openai_api_key, rag_option, prompt):
     completion = ""
     result = ""
     chain = None
-    status_msg = ""
+    err_msg = ""
     try:
         start_time_ms = round(time.time() * 1000)
         llm = ChatOpenAI(model_name = config["model"],
@@ -179,15 +182,14 @@ def invoke(openai_api_key, rag_option, prompt):
             completion, chain = rag_chain(llm, prompt, db)
             result = completion["result"]
         else:
-            result, chain = llm_chain(llm, prompt)
-            print(result)
-            completion = result
+            completion, chain = llm_chain(llm, prompt)
+            result = completion.text
     except Exception as e:
-        status_msg = e
+        err_msg = e
         raise gr.Error(e)
     finally:
         end_time_ms = round(time.time() * 1000)
-        wandb_trace(rag_option, prompt, completion, chain, status_msg, start_time_ms, end_time_ms)
+        wandb_trace(rag_option, prompt, completion, chain, err_msg, start_time_ms, end_time_ms)
     return result
 
 gr.close_all()
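The reworked invoke() keeps its overall shape: exceptions are captured into err_msg and re-raised as gr.Error for the UI, while the finally block guarantees wandb_trace fires on both the success and the failure path. A self-contained sketch of that control flow, with hypothetical stand-ins for the app's chain and trace helpers:

```python
import time
import gradio as gr

def run_chain(prompt):
    return f"echo: {prompt}"  # hypothetical stand-in for llm_chain / rag_chain

def log_trace(prompt, completion, err_msg, start_ms, end_ms):
    # hypothetical stand-in for wandb_trace; prints instead of logging to W&B
    status = "success" if str(err_msg) == "" else "error"
    print(status, repr(completion), end_ms - start_ms, "ms")

def invoke(prompt):
    completion, result, err_msg = "", "", ""
    start_time_ms = round(time.time() * 1000)
    try:
        completion = run_chain(prompt)
        result = completion
    except Exception as e:
        err_msg = e
        raise gr.Error(e)  # surface the failure in the Gradio UI
    finally:
        # runs on success and on error, so every call produces a trace
        end_time_ms = round(time.time() * 1000)
        log_trace(prompt, completion, err_msg, start_time_ms, end_time_ms)
    return result

print(invoke("hello"))
```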
 