File size: 2,075 Bytes
e874531
fcd3a75
a857f12
 
9e0f123
 
301617d
b552593
86ffba3
 
 
 
c097673
86ffba3
 
 
a857f12
 
 
 
 
 
 
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
 
b4f203e
a857f12
 
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
 
 
 
 
 
86ffba3
a857f12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os, wandb

from wandb.sdk.data_types.trace_tree import Trace

# Evaluated at import time: importing this module raises KeyError when the
# WANDB_API_KEY environment variable is not set. The value itself is not
# referenced anywhere else in the code shown here.
# NOTE(review): presumably wandb reads the key from the environment on its own,
# making this line a fail-fast presence check only — confirm before removing.
WANDB_API_KEY = os.environ["WANDB_API_KEY"]

def _inner_llm_chain(is_rag_off, chain):
    """Return the object exposing ``.prompt`` and ``.llm`` for either chain shape."""
    # A plain LLM chain carries prompt/llm directly; a retrieval chain nests
    # them under ``combine_documents_chain.llm_chain``.
    return chain if is_rag_off else chain.combine_documents_chain.llm_chain


def trace_wandb(config,
                is_rag_off,
                prompt,
                completion,
                result,
                chain,
                cb,
                err_msg,
                start_time_ms,
                end_time_ms):
    """Log one chain invocation to Weights & Biases as a ``Trace``.

    Opens a W&B run in the ``openai-llm-rag`` project, logs a single
    ``Trace`` of kind ``"chain"`` under the key ``"evaluation"``, and always
    finishes the run, even if building or logging the trace fails.

    Args:
        config: Not used by this function.
            NOTE(review): possibly intended for ``wandb.init(config=...)``;
            confirm with callers.
        is_rag_off: True when ``chain`` exposes ``.prompt`` and ``.llm``
            directly. False when it is a retrieval chain exposing
            ``.combine_documents_chain.llm_chain`` and ``.retriever``.
        prompt: User prompt, logged under ``inputs["prompt"]``.
        completion: Raw chain output, logged via ``str()``. When
            ``is_rag_off`` is False it must support
            ``completion["source_documents"]``, whose items have
            ``metadata["source"]``.
        result: Final answer, logged under ``outputs["result"]``.
        chain: The chain that was run, or None. It is only dereferenced on
            the success path.
        cb: Callback object whose ``str()`` is logged under ``outputs["cb"]``.
        err_msg: Error message or exception. ``str(err_msg) == ""`` marks a
            successful run.
        start_time_ms: Start time, forwarded unchanged to ``Trace``.
        end_time_ms: End time, forwarded unchanged to ``Trace``.
    """
    error_text = str(err_msg)
    succeeded = error_text == ""

    # Build the payload before opening a run. On error the three dicts stay
    # empty, so ``chain`` (which may be None) is never dereferenced.
    inputs, outputs, model_dict = {}, {}, {}
    if succeeded:
        llm_chain = _inner_llm_chain(is_rag_off, chain)
        llm = llm_chain.llm
        inputs = {
            "is_rag": not is_rag_off,
            "prompt": prompt,
            "chain_prompt": str(llm_chain.prompt),
            "source_documents": "" if is_rag_off else str(
                [doc.metadata["source"] for doc in completion["source_documents"]]
            ),
        }
        outputs = {
            "result": result,
            "cb": str(cb),
            "completion": str(completion),
        }
        model_dict = {
            "client": str(llm.client),
            "model_name": str(llm.model_name),
            "temperature": str(llm.temperature),
            "retriever": "" if is_rag_off else str(chain.retriever),
        }

    wandb.init(project="openai-llm-rag")
    try:
        trace = Trace(
            kind="chain",
            name="" if chain is None else type(chain).__name__,
            status_code="success" if succeeded else "error",
            status_message=error_text,
            inputs=inputs,
            outputs=outputs,
            model_dict=model_dict,
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
        )
        trace.log("evaluation")
    finally:
        # Close the run even when Trace construction or logging raises, so a
        # failure does not leave an active run behind for the next call.
        wandb.finish()