bstraehle commited on
Commit
3d0c9a5
·
1 Parent(s): 601accc

Create trace.py

Browse files
Files changed (1) hide show
  1. trace.py +48 -0
trace.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, wandb
2
+
3
+ from wandb.sdk.data_types.trace_tree import Trace
4
+
5
+ WANDB_API_KEY = os.environ["WANDB_API_KEY"]
6
+
7
+ AGENT_LANGCHAIN = "LangChain"
8
+ AGENT_LLAMAINDEX = "LlamaIndex"
9
+
10
+ def trace_wandb(config,
11
+ agent_option,
12
+ prompt,
13
+ completion,
14
+ result,
15
+ callback,
16
+ err_msg,
17
+ start_time_ms,
18
+ end_time_ms):
19
+ wandb.init(project = "openai-llm-agent")
20
+
21
+ if (agent_option == AGENT_LANGCHAIN):
22
+ prompt_template = os.environ["LANGCHAIN_TEMPLATE"]
23
+ elif (agent_option == AGENT_LLAMAINDEX):
24
+ prompt_template = os.environ["LLAMAINDEX_TEMPLATE"]
25
+ else:
26
+ prompt_template = os.environ["TEMPLATE"]
27
+
28
+ trace = Trace(
29
+ kind = "LLM",
30
+ name = "Real-Time Reasoning Application",
31
+ status_code = "success" if (str(err_msg) == "") else "error",
32
+ status_message = str(err_msg),
33
+ inputs = {"prompt": prompt,
34
+ "prompt_template": prompt_template,
35
+ "agent_option": agent_option,
36
+ "config": str(config)
37
+ } if (str(err_msg) == "") else {},
38
+ outputs = {"result": str(result),
39
+ "callback": str(callback),
40
+ "completion": str(completion)
41
+ } if (str(err_msg) == "") else {},
42
+ start_time_ms = start_time_ms,
43
+ end_time_ms = end_time_ms
44
+ )
45
+
46
+ trace.log("evaluation")
47
+
48
+ wandb.finish()