bstraehle committed
Commit 2f29f7e · 1 Parent(s): d4818ca

Update app.py

Files changed (1): app.py (+6, -9)
app.py CHANGED
@@ -30,7 +30,6 @@ MONGODB_INDEX_NAME = "default"
 
 config = {
     "model": "gpt-4",
-    "rag_option": "Off",
     "temperature": 0,
 }
 
@@ -106,7 +105,6 @@ def document_retrieval_mongodb(llm, prompt):
 def llm_chain(llm, prompt):
     llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
     completion = llm_chain.run({"question": prompt})
-    wandb.log({"prompt": prompt, "completion": completion})
     return completion
 
 def rag_chain(llm, prompt, db):
@@ -115,9 +113,7 @@ def rag_chain(llm, prompt, db):
                                             retriever = db.as_retriever(search_kwargs = {"k": 3}),
                                             return_source_documents = True)
     completion = rag_chain({"query": prompt})
-    completion_result = completion["result"]
-    wandb.log({"prompt": prompt, "completion": completion_result})
-    return completion_result
+    return completion["result"]
 
 def invoke(openai_api_key, rag_option, prompt):
     if (openai_api_key == ""):
@@ -126,7 +122,6 @@ def invoke(openai_api_key, rag_option, prompt):
         raise gr.Error("Retrieval Augmented Generation is required.")
     if (prompt == ""):
         raise gr.Error("Prompt is required.")
-    #wandb.config[rag_option] = rag_option
     try:
         llm = ChatOpenAI(model_name = config.model,
                          openai_api_key = openai_api_key,
@@ -135,17 +130,19 @@ def invoke(openai_api_key, rag_option, prompt):
             #splits = document_loading_splitting()
             #document_storage_chroma(splits)
             db = document_retrieval_chroma(llm, prompt)
-            result = rag_chain(llm, prompt, db)
+            completion = rag_chain(llm, prompt, db)
         elif (rag_option == "MongoDB"):
             #splits = document_loading_splitting()
             #document_storage_mongodb(splits)
             db = document_retrieval_mongodb(llm, prompt)
-            result = rag_chain(llm, prompt, db)
+            completion = rag_chain(llm, prompt, db)
         else:
             result = llm_chain(llm, prompt)
     except Exception as e:
         raise gr.Error(e)
-    return result
+    wandb.config[rag_option] = rag_option
+    wandb.log({"prompt": prompt, "completion": completion})
+    return completion
 
 description = """<strong>Overview:</strong> Context-aware multimodal reasoning application that demonstrates a <strong>large language model (LLM)</strong> with
 <strong>retrieval augmented generation (RAG)</strong>.
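
The net effect of the commit is that Weights & Biases logging moves out of llm_chain and rag_chain and into invoke, so each request is logged exactly once, after whichever chain handled it. A minimal, self-contained sketch of that pattern, assuming a W&B run is initialized elsewhere in app.py (the project name and the stub completion below are hypothetical):

import wandb

# Offline mode so the sketch runs without a W&B account; app.py would use its own init call.
wandb.init(project = "openai-llm-rag", mode = "offline")

def log_invocation(rag_option, prompt, completion):
    # One config entry for the retrieval backend and one log entry per request,
    # regardless of which chain produced the completion.
    wandb.config["rag_option"] = rag_option
    wandb.log({"prompt": prompt, "completion": completion})
    return completion

log_invocation("Chroma", "What is RAG?", "Retrieval augmented generation is ...")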
 
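For context, the rag_chain hunk shows only the tail of the chain constructor. In LangChain releases of this period that constructor is typically RetrievalQA.from_chain_type; a sketch of how the function in the diff would read in full, with the constructor call and chain_type as assumptions (only the retriever and return_source_documents arguments are visible in the hunk):

from langchain.chains import RetrievalQA

def rag_chain(llm, prompt, db):
    # Assumed constructor; the diff shows only the retriever and return_source_documents arguments.
    rag_chain = RetrievalQA.from_chain_type(llm,
                                            chain_type = "stuff",
                                            retriever = db.as_retriever(search_kwargs = {"k": 3}),
                                            return_source_documents = True)
    completion = rag_chain({"query": prompt})   # returns a dict with "result" and "source_documents"
    return completion["result"]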
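The description string feeds the Gradio UI that calls invoke. A sketch of that wiring under stated assumptions (the component labels and output textbox are guesses; only invoke's three parameters and the "Off" / "Chroma" / "MongoDB" options appear in the diff):

import gradio as gr

demo = gr.Interface(fn = invoke,
                    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password"),
                              gr.Radio(["Off", "Chroma", "MongoDB"], label = "Retrieval Augmented Generation"),
                              gr.Textbox(label = "Prompt")],
                    outputs = gr.Textbox(label = "Completion"),
                    description = description)
demo.launch()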