bstraehle committed on
Commit
996e450
·
1 Parent(s): 5397e21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -30,23 +30,26 @@ YOUTUBE_URL = "https://www.youtube.com/watch?v=--khbXchTeE"
30
  MODEL_NAME = "gpt-4"
31
 
def invoke(openai_api_key, use_rag, prompt):
    """Answer ``prompt`` with the chat model, rebuilding the vector store on every call.

    Args:
        openai_api_key: API key handed to ``ChatOpenAI``.
        use_rag: When truthy, the documents loaded from ``YOUTUBE_URL`` are
            split, indexed into Chroma, and the 3 nearest chunks are
            retrieved for the question.
        prompt: The user's question, passed to the chain as ``"query"``.

    Returns:
        The ``"result"`` entry of the chain output.
    """
    # Start from a clean slate: any index or download left by a previous call is removed.
    if (os.path.isdir(CHROMA_DIR)):
        shutil.rmtree(CHROMA_DIR)
    if (os.path.isdir(YOUTUBE_DIR)):
        shutil.rmtree(YOUTUBE_DIR)

    # Both branches use the same model configuration, so build it once.
    llm = ChatOpenAI(model_name = MODEL_NAME, openai_api_key = openai_api_key, temperature = 0)

    if (use_rag):
        # NOTE(review): OpenAIWhisperParser() and OpenAIEmbeddings() are created without
        # openai_api_key, so they presumably read the key from the environment - confirm.
        loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL], YOUTUBE_DIR), OpenAIWhisperParser())
        docs = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1500, chunk_overlap = 150)
        splits = text_splitter.split_documents(docs)
        vector_db = Chroma.from_documents(documents = splits, embedding = OpenAIEmbeddings(), persist_directory = CHROMA_DIR)
        retriever = vector_db.as_retriever(search_kwargs = {"k": 3})
    else:
        # NOTE(review): RetrievalQA is still built with retriever = None here; confirm the
        # chain accepts that, otherwise this branch needs a retriever-free chain.
        retriever = None

    # The keyword is chain_type_kwargs; the previous non-RAG spelling "cchain_type_kwargs"
    # did not match that parameter, so QA_CHAIN_PROMPT was not passed as a chain option.
    qa_chain = RetrievalQA.from_chain_type(llm, retriever = retriever, return_source_documents = True, chain_type_kwargs = {"prompt": QA_CHAIN_PROMPT})
    result = qa_chain({"query": prompt})
    return result["result"]
52
 
 
30
  MODEL_NAME = "gpt-4"
31
 
def invoke(openai_api_key, use_rag, prompt):
    """Answer ``prompt`` with the chat model, optionally grounded in the indexed video.

    Args:
        openai_api_key: API key handed to ``ChatOpenAI``.
        use_rag: When truthy, the 3 nearest chunks from the Chroma store are
            retrieved for the question. The store under ``CHROMA_DIR`` is
            reused when that directory exists and is built from
            ``YOUTUBE_URL`` otherwise.
        prompt: The user's question, passed to the chain as ``"query"``.

    Returns:
        The ``"result"`` entry of the chain output.
    """
    # Both branches use the same model configuration, so build it once.
    llm = ChatOpenAI(model_name = MODEL_NAME, openai_api_key = openai_api_key, temperature = 0)

    if (use_rag):
        # NOTE(review): OpenAIEmbeddings() and OpenAIWhisperParser() are created without
        # openai_api_key, so they presumably read the key from the environment - confirm.
        if (os.path.isdir(CHROMA_DIR)):
            # A persisted index already exists: open it instead of rebuilding.
            vector_db = Chroma(persist_directory = CHROMA_DIR, embedding_function = OpenAIEmbeddings())
        else:
            # First run: load the source documents, split them and persist the index.
            loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL], YOUTUBE_DIR), OpenAIWhisperParser())
            docs = loader.load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1500, chunk_overlap = 150)
            splits = text_splitter.split_documents(docs)
            vector_db = Chroma.from_documents(documents = splits, embedding = OpenAIEmbeddings(), persist_directory = CHROMA_DIR)
        retriever = vector_db.as_retriever(search_kwargs = {"k": 3})
    else:
        # NOTE(review): RetrievalQA is still built with retriever = None here; confirm the
        # chain accepts that, otherwise this branch needs a retriever-free chain.
        retriever = None

    # The keyword is chain_type_kwargs; the previous non-RAG spelling "cchain_type_kwargs"
    # did not match that parameter, so QA_CHAIN_PROMPT was not passed as a chain option.
    qa_chain = RetrievalQA.from_chain_type(llm, retriever = retriever, return_source_documents = True, chain_type_kwargs = {"prompt": QA_CHAIN_PROMPT})
    result = qa_chain({"query": prompt})
    return result["result"]
55