bstraehle commited on
Commit
2301c17
·
1 Parent(s): ed33d82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -30,8 +30,8 @@ RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"],
30
  CHROMA_DIR = "/data/chroma"
31
  YOUTUBE_DIR = "/data/youtube"
32
 
33
- #YOUTUBE_URL = "https://www.youtube.com/watch?v=--khbXchTeE"
34
- YOUTUBE_URL = "https://www.youtube.com/watch?v=Iy1IpvcJH7I&list=PL2yQDdvlhXf9XsB2W76_seM6dJxcE2Pdc&index=2"
35
 
36
  MODEL_NAME = "gpt-4"
37
 
@@ -40,26 +40,24 @@ def invoke(openai_api_key, use_rag, prompt):
40
  openai_api_key = openai_api_key,
41
  temperature = 0)
42
  if (use_rag):
43
- # if (os.path.isdir(CHROMA_DIR)):
44
- # vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
45
- # persist_directory = CHROMA_DIR)
46
- # print("Load DB")
47
- # else:
48
- loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL], YOUTUBE_DIR),
49
- OpenAIWhisperParser())
50
- docs = loader.load()
51
- text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
52
- chunk_size = 1500)
53
- splits = text_splitter.split_documents(docs)
54
- vector_db = Chroma.from_documents(documents = splits,
55
- embedding = OpenAIEmbeddings(),
56
- persist_directory = CHROMA_DIR)
57
- print("Make DB")
58
  rag_chain = RetrievalQA.from_chain_type(llm,
59
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
60
  retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
61
  return_source_documents = True)
62
- print(os.listdir("/data/chroma/"))
63
  result = rag_chain({"query": prompt})
64
  result = result["result"]
65
  else:
 
30
  CHROMA_DIR = "/data/chroma"
31
  YOUTUBE_DIR = "/data/youtube"
32
 
33
+ YOUTUBE_URL_01 = "https://www.youtube.com/watch?v=--khbXchTeE"
34
+ YOUTUBE_URL_02 = "https://www.youtube.com/watch?v=Iy1IpvcJH7I&list=PL2yQDdvlhXf9XsB2W76_seM6dJxcE2Pdc&index=2"
35
 
36
  MODEL_NAME = "gpt-4"
37
 
 
40
  openai_api_key = openai_api_key,
41
  temperature = 0)
42
  if (use_rag):
43
+ # Document loading, splitting, and storage
44
+ loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_01,
45
+ YOUTUBE_URL_02], YOUTUBE_DIR),
46
+ OpenAIWhisperParser())
47
+ docs = loader.load()
48
+ text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
49
+ chunk_size = 1500)
50
+ splits = text_splitter.split_documents(docs)
51
+ vector_db = Chroma.from_documents(documents = splits,
52
+ embedding = OpenAIEmbeddings(),
53
+ persist_directory = CHROMA_DIR)
54
+ # Document retrieval
55
+ #vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
56
+ # persist_directory = CHROMA_DIR)
 
57
  rag_chain = RetrievalQA.from_chain_type(llm,
58
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
59
  retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
60
  return_source_documents = True)
 
61
  result = rag_chain({"query": prompt})
62
  result = result["result"]
63
  else: