Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -40,7 +40,7 @@ YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
|
|
40 |
|
41 |
MODEL_NAME = "gpt-4"
|
42 |
|
43 |
-
def document_retrieval_chroma():
|
44 |
vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
|
45 |
persist_directory = CHROMA_DIR)
|
46 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
@@ -50,7 +50,7 @@ def document_retrieval_chroma():
|
|
50 |
result = rag_chain({"query": prompt})
|
51 |
return result["result"]
|
52 |
|
53 |
-
def invoke(openai_api_key,
|
54 |
if (openai_api_key == ""):
|
55 |
raise gr.Error("OpenAI API Key is required.")
|
56 |
if (use_rag is None):
|
@@ -61,7 +61,7 @@ def invoke(openai_api_key, use_rag, rag_db, prompt):
|
|
61 |
llm = ChatOpenAI(model_name = MODEL_NAME,
|
62 |
openai_api_key = openai_api_key,
|
63 |
temperature = 0)
|
64 |
-
if (
|
65 |
# Document loading
|
66 |
#docs = []
|
67 |
# Load PDF
|
@@ -93,7 +93,7 @@ def invoke(openai_api_key, use_rag, rag_db, prompt):
|
|
93 |
## return_source_documents = True)
|
94 |
##result = rag_chain({"query": prompt})
|
95 |
##result = result["result"]
|
96 |
-
result = document_retrieval_chroma()
|
97 |
else:
|
98 |
chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
|
99 |
result = chain.run({"question": prompt})
|
@@ -120,8 +120,7 @@ description = """<strong>Overview:</strong> Reasoning application that demonstra
|
|
120 |
gr.close_all()
|
121 |
demo = gr.Interface(fn=invoke,
|
122 |
inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
|
123 |
-
gr.Radio([
|
124 |
-
gr.Radio(["Chroma", "MongoDB"], label="RAG Database", value = "Chroma"),
|
125 |
gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
|
126 |
outputs = [gr.Textbox(label = "Completion", lines = 1)],
|
127 |
title = "Generative AI - LLM & RAG",
|
|
|
40 |
|
41 |
MODEL_NAME = "gpt-4"
|
42 |
|
43 |
+
def document_retrieval_chroma(llm):
|
44 |
vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
|
45 |
persist_directory = CHROMA_DIR)
|
46 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
|
|
50 |
result = rag_chain({"query": prompt})
|
51 |
return result["result"]
|
52 |
|
53 |
+
def invoke(openai_api_key, rag, prompt):
|
54 |
if (openai_api_key == ""):
|
55 |
raise gr.Error("OpenAI API Key is required.")
|
56 |
if (use_rag is None):
|
|
|
61 |
llm = ChatOpenAI(model_name = MODEL_NAME,
|
62 |
openai_api_key = openai_api_key,
|
63 |
temperature = 0)
|
64 |
+
if (rag != "None"):
|
65 |
# Document loading
|
66 |
#docs = []
|
67 |
# Load PDF
|
|
|
93 |
## return_source_documents = True)
|
94 |
##result = rag_chain({"query": prompt})
|
95 |
##result = result["result"]
|
96 |
+
result = document_retrieval_chroma(llm)
|
97 |
else:
|
98 |
chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
|
99 |
result = chain.run({"question": prompt})
|
|
|
120 |
gr.close_all()
|
121 |
demo = gr.Interface(fn=invoke,
|
122 |
inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
|
123 |
+
gr.Radio(["None", "Chroma", "MongoDB"], label="Retrieval Augmented Generation", value = "None"),
|
|
|
124 |
gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
|
125 |
outputs = [gr.Textbox(label = "Completion", lines = 1)],
|
126 |
title = "Generative AI - LLM & RAG",
|