bstraehle committed
Commit 2704db0 · 1 Parent(s): 9549818

Update app.py

Files changed (1):
  app.py (+5 -6)
app.py CHANGED
@@ -40,7 +40,7 @@ YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
 
 MODEL_NAME = "gpt-4"
 
-def document_retrieval_chroma():
+def document_retrieval_chroma(llm):
     vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
                        persist_directory = CHROMA_DIR)
     rag_chain = RetrievalQA.from_chain_type(llm,
@@ -50,7 +50,7 @@ def document_retrieval_chroma():
     result = rag_chain({"query": prompt})
     return result["result"]
 
-def invoke(openai_api_key, use_rag, rag_db, prompt):
+def invoke(openai_api_key, rag, prompt):
     if (openai_api_key == ""):
         raise gr.Error("OpenAI API Key is required.")
     if (use_rag is None):
@@ -61,7 +61,7 @@ def invoke(openai_api_key, use_rag, rag_db, prompt):
     llm = ChatOpenAI(model_name = MODEL_NAME,
                      openai_api_key = openai_api_key,
                      temperature = 0)
-    if (use_rag):
+    if (rag != "None"):
         # Document loading
         #docs = []
         # Load PDF
@@ -93,7 +93,7 @@ def invoke(openai_api_key, use_rag, rag_db, prompt):
         ## return_source_documents = True)
         ##result = rag_chain({"query": prompt})
         ##result = result["result"]
-        result = document_retrieval_chroma()
+        result = document_retrieval_chroma(llm)
     else:
         chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
         result = chain.run({"question": prompt})
@@ -120,8 +120,7 @@ description = """<strong>Overview:</strong> Reasoning application that demonstra
 gr.close_all()
 demo = gr.Interface(fn=invoke,
                     inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
-                              gr.Radio([True, False], label="Retrieval Augmented Generation (RAG)", value = False),
-                              gr.Radio(["Chroma", "MongoDB"], label="RAG Database", value = "Chroma"),
+                              gr.Radio(["None", "Chroma", "MongoDB"], label="Retrieval Augmented Generation", value = "None"),
                               gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
                     outputs = [gr.Textbox(label = "Completion", lines = 1)],
                     title = "Generative AI - LLM & RAG",