Update app.py
Browse files
app.py
CHANGED
@@ -30,8 +30,8 @@ RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"],
|
|
30 |
CHROMA_DIR = "/data/chroma"
|
31 |
YOUTUBE_DIR = "/data/youtube"
|
32 |
|
33 |
-
|
34 |
-
|
35 |
|
36 |
MODEL_NAME = "gpt-4"
|
37 |
|
@@ -40,26 +40,24 @@ def invoke(openai_api_key, use_rag, prompt):
|
|
40 |
openai_api_key = openai_api_key,
|
41 |
temperature = 0)
|
42 |
if (use_rag):
|
43 |
-
#
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
print("Make DB")
|
58 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
59 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
60 |
retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
|
61 |
return_source_documents = True)
|
62 |
-
print(os.listdir("/data/chroma/"))
|
63 |
result = rag_chain({"query": prompt})
|
64 |
result = result["result"]
|
65 |
else:
|
|
|
30 |
CHROMA_DIR = "/data/chroma"
|
31 |
YOUTUBE_DIR = "/data/youtube"
|
32 |
|
33 |
+
YOUTUBE_URL_01 = "https://www.youtube.com/watch?v=--khbXchTeE"
|
34 |
+
YOUTUBE_URL_02 = "https://www.youtube.com/watch?v=Iy1IpvcJH7I&list=PL2yQDdvlhXf9XsB2W76_seM6dJxcE2Pdc&index=2"
|
35 |
|
36 |
MODEL_NAME = "gpt-4"
|
37 |
|
|
|
40 |
openai_api_key = openai_api_key,
|
41 |
temperature = 0)
|
42 |
if (use_rag):
|
43 |
+
# Document loading, splitting, and storage
|
44 |
+
loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_01,
|
45 |
+
YOUTUBE_URL_02], YOUTUBE_DIR),
|
46 |
+
OpenAIWhisperParser())
|
47 |
+
docs = loader.load()
|
48 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
|
49 |
+
chunk_size = 1500)
|
50 |
+
splits = text_splitter.split_documents(docs)
|
51 |
+
vector_db = Chroma.from_documents(documents = splits,
|
52 |
+
embedding = OpenAIEmbeddings(),
|
53 |
+
persist_directory = CHROMA_DIR)
|
54 |
+
# Document retrieval
|
55 |
+
#vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
|
56 |
+
# persist_directory = CHROMA_DIR)
|
|
|
57 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
58 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
59 |
retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
|
60 |
return_source_documents = True)
|
|
|
61 |
result = rag_chain({"query": prompt})
|
62 |
result = result["result"]
|
63 |
else:
|