bstraehle committed on
Commit
86d2f65
·
1 Parent(s): 7b3bd25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -28
app.py CHANGED
@@ -40,6 +40,54 @@ YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
40
 
41
  MODEL_NAME = "gpt-4"
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def document_retrieval_chroma(llm, prompt):
44
  vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
45
  persist_directory = CHROMA_DIR)
@@ -60,10 +108,10 @@ def document_retrieval_mongodb(llm, prompt):
60
  result = rag_chain({"query": prompt})
61
  return result["result"]
62
 
63
- def invoke(openai_api_key, rag, prompt):
64
  if (openai_api_key == ""):
65
  raise gr.Error("OpenAI API Key is required.")
66
- if (rag is None):
67
  raise gr.Error("Retrieval Augmented Generation is required.")
68
  if (prompt == ""):
69
  raise gr.Error("Prompt is required.")
@@ -73,33 +121,11 @@ def invoke(openai_api_key, rag, prompt):
73
  openai_api_key = openai_api_key,
74
  temperature = 0)
75
 
76
- if (rag == "Chroma"):
77
- # Document loading
78
- #docs = []
79
- # Load PDF
80
- #loader = PyPDFLoader(PDF_URL)
81
- #docs.extend(loader.load())
82
- # Load Web
83
- #loader = WebBaseLoader(WEB_URL_1)
84
- #docs.extend(loader.load())
85
- # Load YouTube
86
- #loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
87
- # YOUTUBE_URL_2,
88
- # YOUTUBE_URL_3], YOUTUBE_DIR),
89
- # OpenAIWhisperParser())
90
- #docs.extend(loader.load())
91
- # Document splitting
92
- #text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
93
- # chunk_size = 1500)
94
- #splits = text_splitter.split_documents(docs)
95
- # Document storage
96
- #vector_db = Chroma.from_documents(documents = splits,
97
- # embedding = OpenAIEmbeddings(disallowed_special = ()),
98
- # persist_directory = CHROMA_DIR)
99
- # Document retrieval
100
  result = document_retrieval_chroma(llm, prompt)
101
- elif (rag == "MongoDB"):
102
- # Document retrieval
103
  result = document_retrieval_mongodb(llm, prompt)
104
  else:
105
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
 
40
 
41
  MODEL_NAME = "gpt-4"
42
 
43
def document_storage_chroma():
    """Load the source documents, split them, and write them to Chroma.

    Takes no arguments and returns None. Documents are loaded in a fixed
    order (PDF_URL, WEB_URL_1, then the three YouTube URLs), split into
    chunks, embedded with OpenAIEmbeddings and stored through
    Chroma.from_documents with CHROMA_DIR as the persist directory. The
    store object returned by Chroma.from_documents is not kept.
    """
    # Document loading: each loader is created and run before the next one
    pdf_docs = PyPDFLoader(PDF_URL).load()
    web_docs = WebBaseLoader(WEB_URL_1).load()
    youtube_urls = [YOUTUBE_URL_1, YOUTUBE_URL_2, YOUTUBE_URL_3]
    youtube_loader = GenericLoader(YoutubeAudioLoader(youtube_urls, YOUTUBE_DIR),
                                   OpenAIWhisperParser())
    youtube_docs = youtube_loader.load()
    documents = [*pdf_docs, *web_docs, *youtube_docs]

    # Document splitting
    splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
                                              chunk_size = 1500)
    chunks = splitter.split_documents(documents)

    # Document storage
    Chroma.from_documents(documents = chunks,
                          embedding = OpenAIEmbeddings(disallowed_special = ()),
                          persist_directory = CHROMA_DIR)
66
+
67
def document_storage_mongodb():
    """Load the source documents, split them, and store their embeddings.

    Takes no arguments and returns None. Documents are loaded in a fixed
    order (PDF_URL, WEB_URL_1, then the three YouTube URLs) and split into
    chunks before being embedded with OpenAIEmbeddings.

    NOTE(review): despite the function name, the storage step calls
    Chroma.from_documents with CHROMA_DIR, exactly like
    document_storage_chroma; nothing in this body touches MongoDB. This
    looks like a copy-paste placeholder - confirm the intended MongoDB
    vector store and its collection/index settings before enabling it.
    """
    docs = []

    # Document loading: PDF first
    pdf_loader = PyPDFLoader(PDF_URL)
    docs += pdf_loader.load()

    # Then the web page
    web_loader = WebBaseLoader(WEB_URL_1)
    docs += web_loader.load()

    # Then the YouTube audio
    audio_source = YoutubeAudioLoader([YOUTUBE_URL_1,
                                       YOUTUBE_URL_2,
                                       YOUTUBE_URL_3], YOUTUBE_DIR)
    audio_loader = GenericLoader(audio_source, OpenAIWhisperParser())
    docs += audio_loader.load()

    # Document splitting
    pieces = RecursiveCharacterTextSplitter(chunk_overlap = 150,
                                            chunk_size = 1500).split_documents(docs)

    # Document storage (currently Chroma, see NOTE above); result is discarded
    Chroma.from_documents(documents = pieces,
                          embedding = OpenAIEmbeddings(disallowed_special = ()),
                          persist_directory = CHROMA_DIR)
90
+
91
  def document_retrieval_chroma(llm, prompt):
92
  vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
93
  persist_directory = CHROMA_DIR)
 
108
  result = rag_chain({"query": prompt})
109
  return result["result"]
110
 
111
+ def invoke(openai_api_key, rag_option, prompt):
112
  if (openai_api_key == ""):
113
  raise gr.Error("OpenAI API Key is required.")
114
+ if (rag_option is None):
115
  raise gr.Error("Retrieval Augmented Generation is required.")
116
  if (prompt == ""):
117
  raise gr.Error("Prompt is required.")
 
121
  openai_api_key = openai_api_key,
122
  temperature = 0)
123
 
124
+ if (rag_option == "Chroma"):
125
+ #document_storage_chroma()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  result = document_retrieval_chroma(llm, prompt)
127
+ elif (rag_option == "MongoDB"):
128
+ #document_storage_mongodb()
129
  result = document_retrieval_mongodb(llm, prompt)
130
  else:
131
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)