bstraehle committed
Commit e41fd61 · 1 Parent(s): 5b9fc25

Update rag.py

Files changed (1)
  1. rag.py +43 -28
rag.py CHANGED
@@ -31,8 +31,12 @@ MONGODB_DB_NAME = "langchain_db"
 MONGODB_COLLECTION_NAME = "gpt-4"
 MONGODB_INDEX_NAME = "default"
 
-LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
-RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
+LLM_CHAIN_PROMPT = PromptTemplate(
+    input_variables = ["question"],
+    template = os.environ["LLM_TEMPLATE"])
+RAG_CHAIN_PROMPT = PromptTemplate(
+    input_variables = ["context", "question"],
+    template = os.environ["RAG_TEMPLATE"])
 
 client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
 collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
@@ -49,28 +53,34 @@ def load_documents():
     docs.extend(loader.load())
 
     # YouTube
-    loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1, YOUTUBE_URL_2], YOUTUBE_DIR),
-                           OpenAIWhisperParser())
+    loader = GenericLoader(
+        YoutubeAudioLoader(
+            [YOUTUBE_URL_1, YOUTUBE_URL_2],
+            YOUTUBE_DIR),
+        OpenAIWhisperParser())
     docs.extend(loader.load())
 
     return docs
 
 def split_documents(config, docs):
-    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
-                                                   chunk_size = config["chunk_size"])
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_overlap = config["chunk_overlap"],
+        chunk_size = config["chunk_size"])
 
     return text_splitter.split_documents(docs)
 
 def store_chroma(chunks):
-    Chroma.from_documents(documents = chunks,
-                          embedding = OpenAIEmbeddings(disallowed_special = ()),
-                          persist_directory = CHROMA_DIR)
+    Chroma.from_documents(
+        documents = chunks,
+        embedding = OpenAIEmbeddings(disallowed_special = ()),
+        persist_directory = CHROMA_DIR)
 
 def store_mongodb(chunks):
-    MongoDBAtlasVectorSearch.from_documents(documents = chunks,
-                                            embedding = OpenAIEmbeddings(disallowed_special = ()),
-                                            collection = collection,
-                                            index_name = MONGODB_INDEX_NAME)
+    MongoDBAtlasVectorSearch.from_documents(
+        documents = chunks,
+        embedding = OpenAIEmbeddings(disallowed_special = ()),
+        collection = collection,
+        index_name = MONGODB_INDEX_NAME)
 
 def rag_ingestion(config):
     docs = load_documents()
@@ -81,22 +91,26 @@ def rag_ingestion(config):
     store_mongodb(chunks)
 
 def retrieve_chroma():
-    return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
-                  persist_directory = CHROMA_DIR)
+    return Chroma(
+        embedding_function = OpenAIEmbeddings(disallowed_special = ()),
+        persist_directory = CHROMA_DIR)
 
 def retrieve_mongodb():
-    return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
-                                                           MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
-                                                           OpenAIEmbeddings(disallowed_special = ()),
-                                                           index_name = MONGODB_INDEX_NAME)
+    return MongoDBAtlasVectorSearch.from_connection_string(
+        MONGODB_ATLAS_CLUSTER_URI,
+        MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
+        OpenAIEmbeddings(disallowed_special = ()),
+        index_name = MONGODB_INDEX_NAME)
 
 def get_llm(config):
-    return ChatOpenAI(model_name = config["model_name"],
-                      temperature = config["temperature"])
+    return ChatOpenAI(
+        model_name = config["model_name"],
+        temperature = config["temperature"])
 
 def llm_chain(config, prompt):
-    llm_chain = LLMChain(llm = get_llm(config),
-                         prompt = LLM_CHAIN_PROMPT)
+    llm_chain = LLMChain(
+        llm = get_llm(config),
+        prompt = LLM_CHAIN_PROMPT)
 
     with get_openai_callback() as cb:
         completion = llm_chain.generate([{"question": prompt}])
@@ -111,11 +125,12 @@ def rag_chain(config, rag_option, prompt):
     elif (rag_option == RAG_MONGODB):
         db = retrieve_mongodb()
 
-    rag_chain = RetrievalQA.from_chain_type(llm,
-                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
-                                                                 "verbose": True},
-                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
-                                            return_source_documents = True)
+    rag_chain = RetrievalQA.from_chain_type(
+        llm,
+        chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
+                             "verbose": True},
+        retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
+        return_source_documents = True)
 
     with get_openai_callback() as cb:
         completion = rag_chain({"query": prompt})
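
For context, a minimal usage sketch of the module as it stands after this commit. The function names and config keys (chunk_overlap, chunk_size, model_name, temperature, k) are taken from the diff; the concrete values, the placeholder templates, and the assumption that RAG_MONGODB is a module-level constant are illustrative, not part of this commit.

# Hypothetical driver script, not part of rag.py.
import os

# rag.py reads these environment variables at import time, so set them first.
# The templates below are placeholders; the real ones are not shown in the diff.
os.environ.setdefault("LLM_TEMPLATE", "Answer the question: {question}")
os.environ.setdefault("RAG_TEMPLATE", "Context: {context}\n\nQuestion: {question}")

import rag  # noqa: E402  (import deferred until the environment is prepared)

config = {
    "chunk_overlap": 150,    # RecursiveCharacterTextSplitter
    "chunk_size": 1500,
    "model_name": "gpt-4",   # ChatOpenAI
    "temperature": 0,
    "k": 3,                  # retriever top-k
}

rag.rag_ingestion(config)   # load, split, and store documents in Chroma and MongoDB Atlas
rag.rag_chain(config, rag.RAG_MONGODB, "What is GPT-4?")  # RAG_MONGODB assumed to be defined in rag.py

Running the sketch also presumes working OpenAI and MongoDB Atlas credentials; how those are configured is outside the scope of this diff.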