bstraehle committed
Commit e02bd6d · 1 Parent(s): a857f12

Update app.py

Files changed (1)
  1. app.py +39 -38
app.py CHANGED
@@ -14,7 +14,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.vectorstores import MongoDBAtlasVectorSearch
 from pymongo import MongoClient
-from wandb.sdk.data_types.trace_tree import Trace
+from trace import wandb_trace
+#from wandb.sdk.data_types.trace_tree import Trace
 
 _ = load_dotenv(find_dotenv())
 
@@ -35,7 +36,7 @@ MONGODB_INDEX_NAME = "default"
 LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
 RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
 
-WANDB_API_KEY = os.environ["WANDB_API_KEY"]
+#WANDB_API_KEY = os.environ["WANDB_API_KEY"]
 
 RAG_OFF = "Off"
 RAG_CHROMA = "Chroma"
@@ -115,42 +116,42 @@ def rag_chain(llm, prompt, db):
     completion = rag_chain({"query": prompt})
     return completion, rag_chain
 
-def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
-    wandb.init(project = "openai-llm-rag")
-
-    trace = Trace(
-        kind = "chain",
-        name = "" if (chain == None) else type(chain).__name__,
-        status_code = "success" if (str(err_msg) == "") else "error",
-        status_message = str(err_msg),
-        metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
-                    "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
-                   } if (str(err_msg) == "") else {},
-        inputs = {"rag_option": rag_option,
-                  "prompt": prompt,
-                  "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
-                                   str(chain.combine_documents_chain.llm_chain.prompt)),
-                  "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
-                 } if (str(err_msg) == "") else {},
-        outputs = {"result": result,
-                   "generation_info": str(generation_info),
-                   "llm_output": str(llm_output),
-                   "completion": str(completion),
-                  } if (str(err_msg) == "") else {},
-        model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
-                                 str(chain.combine_documents_chain.llm_chain.llm.client)),
-                      "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
-                                     str(chain.combine_documents_chain.llm_chain.llm.model_name)),
-                      "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
-                                      str(chain.combine_documents_chain.llm_chain.llm.temperature)),
-                      "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
-                     } if (str(err_msg) == "") else {},
-        start_time_ms = start_time_ms,
-        end_time_ms = end_time_ms
-    )
-
-    trace.log("evaluation")
-    wandb.finish()
+#def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
+#    wandb.init(project = "openai-llm-rag")
+#
+#    trace = Trace(
+#        kind = "chain",
+#        name = "" if (chain == None) else type(chain).__name__,
+#        status_code = "success" if (str(err_msg) == "") else "error",
+#        status_message = str(err_msg),
+#        metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
+#                    "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
+#                   } if (str(err_msg) == "") else {},
+#        inputs = {"rag_option": rag_option,
+#                  "prompt": prompt,
+#                  "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
+#                                   str(chain.combine_documents_chain.llm_chain.prompt)),
+#                  "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
+#                 } if (str(err_msg) == "") else {},
+#        outputs = {"result": result,
+#                   "generation_info": str(generation_info),
+#                   "llm_output": str(llm_output),
+#                   "completion": str(completion),
+#                  } if (str(err_msg) == "") else {},
+#        model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
+#                                 str(chain.combine_documents_chain.llm_chain.llm.client)),
+#                      "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
+#                                     str(chain.combine_documents_chain.llm_chain.llm.model_name)),
+#                      "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
+#                                      str(chain.combine_documents_chain.llm_chain.llm.temperature)),
+#                      "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
+#                     } if (str(err_msg) == "") else {},
+#        start_time_ms = start_time_ms,
+#        end_time_ms = end_time_ms
+#    )
+#
+#    trace.log("evaluation")
+#    wandb.finish()
 
 def invoke(openai_api_key, rag_option, prompt):
     if (openai_api_key == ""):
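The new "from trace import wandb_trace" line implies a companion trace.py module in the Space that now hosts the logging helper. That file is not part of this diff, so the following is only a sketch of what it presumably needs: the wandb imports dropped from app.py, plus some way to reach config and RAG_OFF, which the removed body reads but which are defined in app.py. Note also that a top-level module named trace shadows Python's standard-library trace module.

# trace.py -- hypothetical sketch; the real file is not shown in this commit.
# Assumption: the body of wandb_trace() was moved over verbatim from the lines removed from app.py above.
import wandb
from wandb.sdk.data_types.trace_tree import Trace

RAG_OFF = "Off"  # assumption: duplicated here (or otherwise shared), since app.py defines it and the body compares against it
config = {"chunk_size": "", "chunk_overlap": ""}  # assumption: the chunking config read by the body must also be made available here

def wandb_trace(rag_option, prompt, completion, result, generation_info, llm_output, chain, err_msg, start_time_ms, end_time_ms):
    wandb.init(project = "openai-llm-rag")
    trace = Trace(
        kind = "chain",
        name = "" if (chain == None) else type(chain).__name__,
        status_code = "success" if (str(err_msg) == "") else "error",
        status_message = str(err_msg),
        # metadata, inputs, outputs and model_dict are built exactly as in the removed app.py lines above
        start_time_ms = start_time_ms,
        end_time_ms = end_time_ms,
    )
    trace.log("evaluation")
    wandb.finish()

Since the commented-out copy in app.py keeps the same signature, the call sites in invoke() presumably stay unchanged and only the import moves; keeping the W&B plumbing in its own module leaves app.py focused on the Gradio UI and the LLM/RAG chains.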