bstraehle commited on
Commit
a857f12
·
1 Parent(s): 807ade8

Update trace.py

Browse files
Files changed (1) hide show
  1. trace.py +38 -1
trace.py CHANGED
@@ -1 +1,38 @@
1
- from wandb.sdk.data_types.trace_tree import Trace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from wandb.sdk.data_types.trace_tree import Trace
2
+
3
+ def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
4
+ wandb.init(project = "openai-llm-rag")
5
+
6
+ trace = Trace(
7
+ kind = "chain",
8
+ name = "" if (chain == None) else type(chain).__name__,
9
+ status_code = "success" if (str(err_msg) == "") else "error",
10
+ status_message = str(err_msg),
11
+ metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
12
+ "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
13
+ } if (str(err_msg) == "") else {},
14
+ inputs = {"rag_option": rag_option,
15
+ "prompt": prompt,
16
+ "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
17
+ str(chain.combine_documents_chain.llm_chain.prompt)),
18
+ "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
19
+ } if (str(err_msg) == "") else {},
20
+ outputs = {"result": result,
21
+ "generation_info": str(generation_info),
22
+ "llm_output": str(llm_output),
23
+ "completion": str(completion),
24
+ } if (str(err_msg) == "") else {},
25
+ model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
26
+ str(chain.combine_documents_chain.llm_chain.llm.client)),
27
+ "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
28
+ str(chain.combine_documents_chain.llm_chain.llm.model_name)),
29
+ "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
30
+ str(chain.combine_documents_chain.llm_chain.llm.temperature)),
31
+ "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
32
+ } if (str(err_msg) == "") else {},
33
+ start_time_ms = start_time_ms,
34
+ end_time_ms = end_time_ms
35
+ )
36
+
37
+ trace.log("evaluation")
38
+ wandb.finish()