bstraehle commited on
Commit
04a1583
·
1 Parent(s): 996e450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import openai, os, shutil
3
 
4
  from langchain.chains import RetrievalQA
5
  from langchain.chat_models import ChatOpenAI
@@ -30,10 +30,7 @@ YOUTUBE_URL = "https://www.youtube.com/watch?v=--khbXchTeE"
30
  MODEL_NAME = "gpt-4"
31
 
32
  def invoke(openai_api_key, use_rag, prompt):
33
- # if (os.path.isdir(CHROMA_DIR)):
34
- # shutil.rmtree(CHROMA_DIR)
35
- # if (os.path.isdir(YOUTUBE_DIR)):
36
- # shutil.rmtree(YOUTUBE_DIR)
37
  if (use_rag):
38
  if (os.path.isdir(CHROMA_DIR)):
39
  vector_db = Chroma(persist_directory = CHROMA_DIR, embedding_function = OpenAIEmbeddings())
@@ -43,13 +40,13 @@ def invoke(openai_api_key, use_rag, prompt):
43
  text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1500, chunk_overlap = 150)
44
  splits = text_splitter.split_documents(docs)
45
  vector_db = Chroma.from_documents(documents = splits, embedding = OpenAIEmbeddings(), persist_directory = CHROMA_DIR)
46
- llm = ChatOpenAI(model_name = MODEL_NAME, openai_api_key = openai_api_key, temperature = 0)
47
- qa_chain = RetrievalQA.from_chain_type(llm, retriever = vector_db.as_retriever(search_kwargs = {"k": 3}), return_source_documents = True, chain_type_kwargs = {"prompt": QA_CHAIN_PROMPT})
48
- result = qa_chain({"query": prompt})
49
  else:
50
- llm = ChatOpenAI(model_name = MODEL_NAME, openai_api_key = openai_api_key, temperature = 0)
51
- qa_chain = RetrievalQA.from_chain_type(llm, retriever = None, return_source_documents = True, cchain_type_kwargs = {"prompt": QA_CHAIN_PROMPT})
52
- result = qa_chain({"query": prompt})
 
53
  #print(result)
54
  return result["result"]
55
 
 
1
  import gradio as gr
2
+ import openai, os
3
 
4
  from langchain.chains import RetrievalQA
5
  from langchain.chat_models import ChatOpenAI
 
30
  MODEL_NAME = "gpt-4"
31
 
32
  def invoke(openai_api_key, use_rag, prompt):
33
+ llm = ChatOpenAI(model_name = MODEL_NAME, openai_api_key = openai_api_key, temperature = 0)
 
 
 
34
  if (use_rag):
35
  if (os.path.isdir(CHROMA_DIR)):
36
  vector_db = Chroma(persist_directory = CHROMA_DIR, embedding_function = OpenAIEmbeddings())
 
40
  text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1500, chunk_overlap = 150)
41
  splits = text_splitter.split_documents(docs)
42
  vector_db = Chroma.from_documents(documents = splits, embedding = OpenAIEmbeddings(), persist_directory = CHROMA_DIR)
43
+ rag_chain = RetrievalQA.from_chain_type(llm, retriever = vector_db.as_retriever(search_kwargs = {"k": 3}), return_source_documents = True, chain_type_kwargs = {"prompt": QA_CHAIN_PROMPT})
44
+ result = rag_chain({"query": prompt})
 
45
  else:
46
+ #qa_chain = RetrievalQA.from_chain_type(llm, retriever = None, return_source_documents = True, cchain_type_kwargs = {"prompt": QA_CHAIN_PROMPT})
47
+ #result = qa_chain({"query": prompt})
48
+ chain = LLMChain(llm = llm)
49
+ result = chain({"query": prompt})
50
  #print(result)
51
  return result["result"]
52