arjunanand13 commited on
Commit
71dae17
·
verified ·
1 Parent(s): 2900f61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -76
app.py CHANGED
@@ -1,28 +1,26 @@
1
  import os
2
  import multiprocessing
3
  import concurrent.futures
4
- # from langchain.document_loaders import TextLoader, DirectoryLoader
5
  from langchain_community.document_loaders import TextLoader, DirectoryLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- # from langchain.vectorstores import FAISS
8
  from langchain_community.vectorstores import FAISS
9
  from sentence_transformers import SentenceTransformer
10
  import faiss
11
- import torch
12
  import numpy as np
13
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
14
  from datetime import datetime
15
  import json
16
  import gradio as gr
17
  import re
18
  from threading import Thread
19
-
 
20
  class MultiAgentRAG:
21
- def __init__(self, embedding_model_name, lm_model_id, data_folder):
22
  self.all_splits = self.load_documents(data_folder)
23
  self.embeddings = SentenceTransformer(embedding_model_name)
24
  self.gpu_index = self.create_faiss_index()
25
- self.tokenizer, self.model = self.initialize_llm(lm_model_id)
 
26
 
27
  def load_documents(self, folder_path):
28
  loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
@@ -31,7 +29,7 @@ class MultiAgentRAG:
31
  all_splits = text_splitter.split_documents(documents)
32
  print('Length of documents:', len(documents))
33
  print("LEN of all_splits", len(all_splits))
34
- for i in range(3):
35
  print(all_splits[i].page_content)
36
  return all_splits
37
 
@@ -44,47 +42,20 @@ class MultiAgentRAG:
44
  gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
45
  return gpu_index
46
 
47
- def initialize_llm(self, model_id):
48
- quantization_config = BitsAndBytesConfig(
49
- load_in_4bit=True,
50
- bnb_4bit_use_double_quant=True,
51
- bnb_4bit_quant_type="nf4",
52
- bnb_4bit_compute_dtype=torch.bfloat16
53
- )
54
- tokenizer = AutoTokenizer.from_pretrained(model_id)
55
- model = AutoModelForCausalLM.from_pretrained(
56
- model_id,
57
- torch_dtype=torch.bfloat16,
58
- device_map="auto",
59
- quantization_config=quantization_config
60
- )
61
- return tokenizer, model
62
-
63
- def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
64
  try:
65
- streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
66
- generate_kwargs = dict(
67
- input_ids=input_ids,
68
- max_new_tokens=max_new_tokens,
69
- do_sample=True,
70
- top_p=1.0,
71
- top_k=20,
72
  temperature=0.8,
73
- repetition_penalty=1.2,
74
- eos_token_id=[128001, 128008, 128009],
75
- streamer=streamer,
76
  )
77
-
78
- thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
79
- thread.start()
80
-
81
- generated_text = ""
82
- for new_text in streamer:
83
- generated_text += new_text
84
-
85
- return generated_text
86
  except Exception as e:
87
- print(f"Error in generate_response_with_timeout: {str(e)}")
88
  return "Text generation process encountered an error"
89
 
90
  def retrieval_agent(self, query):
@@ -97,44 +68,51 @@ class MultiAgentRAG:
97
  return content
98
 
99
  def grading_agent(self, query, retrieved_content):
100
- grading_prompt = f"""
101
- Evaluate the relevance of the following retrieved content to the given query:
102
-
103
- Query: {query}
104
-
105
- Retrieved Content:
106
- {retrieved_content}
 
 
 
 
 
 
107
 
108
- Rate the relevance on a scale of 1-10 and explain your rating:
109
- """
110
- input_ids = self.tokenizer.encode(grading_prompt, return_tensors="pt").to(self.model.device)
111
- grading_response = self.generate_response_with_timeout(input_ids)
112
 
113
  # Extract the numerical rating from the response
114
- rating = int(re.search(r'\d+', grading_response).group())
 
115
  return rating, grading_response
116
 
117
  def query_rewrite_agent(self, original_query):
118
- rewrite_prompt = f"""
119
- The following query did not yield relevant results. Please rewrite it to potentially improve retrieval:
120
-
121
- Original Query: {original_query}
 
 
 
 
 
 
122
 
123
- Rewritten Query:
124
- """
125
- input_ids = self.tokenizer.encode(rewrite_prompt, return_tensors="pt").to(self.model.device)
126
- rewritten_query = self.generate_response_with_timeout(input_ids)
127
  return rewritten_query.strip()
128
 
129
  def generation_agent(self, query, retrieved_content):
130
- conversation = [
131
  {"role": "system", "content": "You are a knowledgeable assistant with access to a comprehensive database."},
132
  {"role": "user", "content": f"""
133
  I need you to answer my question and provide related information in a specific format.
134
  I have provided five relatable json files {retrieved_content}, choose the most suitable chunks for answering the query.
135
  RETURN ONLY SOLUTION without additional comments, sign-offs, retrived chunks, refrence to any Ticket or extra phrases. Be direct and to the point.
136
  IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
137
- DO NOT GIVE REFRENCE TO ANY CHUNKS OR TICKETS,BE ON POINT.
138
 
139
  Here's my question:
140
  Query: {query}
@@ -142,24 +120,19 @@ class MultiAgentRAG:
142
  """}
143
  ]
144
 
145
- input_ids = self.tokenizer.encode(self.tokenizer.apply_chat_template(conversation, tokenize=False), return_tensors="pt").to(self.model.device)
146
- return self.generate_response_with_timeout(input_ids)
147
 
148
  def run_multi_agent_rag(self, query):
149
  max_iterations = 3
150
  for i in range(max_iterations):
151
- # Retrieval step
152
  retrieved_content = self.retrieval_agent(query)
153
 
154
- # Grading step
155
  relevance_score, grading_explanation = self.grading_agent(query, retrieved_content)
156
 
157
- if relevance_score >= 7: # Assuming 7 out of 10 is the threshold for relevance
158
- # Generation step
159
  answer = self.generation_agent(query, retrieved_content)
160
  return answer, retrieved_content, grading_explanation
161
  else:
162
- # Query rewrite step
163
  query = self.query_rewrite_agent(query)
164
 
165
  return "Unable to find a relevant answer after multiple attempts.", "", "Low relevance across all attempts."
@@ -204,8 +177,8 @@ def launch_interface(doc_retrieval_gen):
204
 
205
  if __name__ == "__main__":
206
  embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
207
- lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
208
  data_folder = 'sample_embedding_folder2'
209
 
210
- multi_agent_rag = MultiAgentRAG(embedding_model_name, lm_model_id, data_folder)
211
  launch_interface(multi_agent_rag)
 
1
  import os
2
  import multiprocessing
3
  import concurrent.futures
 
4
  from langchain_community.document_loaders import TextLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
6
  from langchain_community.vectorstores import FAISS
7
  from sentence_transformers import SentenceTransformer
8
  import faiss
 
9
  import numpy as np
 
10
  from datetime import datetime
11
  import json
12
  import gradio as gr
13
  import re
14
  from threading import Thread
15
+ from openai import OpenAI
16
+
17
  class MultiAgentRAG:
18
+ def __init__(self, embedding_model_name, openai_model_id, data_folder, api_key=None):
19
  self.all_splits = self.load_documents(data_folder)
20
  self.embeddings = SentenceTransformer(embedding_model_name)
21
  self.gpu_index = self.create_faiss_index()
22
+ self.openai_client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
23
+ self.openai_model_id = openai_model_id
24
 
25
  def load_documents(self, folder_path):
26
  loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
 
29
  all_splits = text_splitter.split_documents(documents)
30
  print('Length of documents:', len(documents))
31
  print("LEN of all_splits", len(all_splits))
32
+ for i in range(min(3, len(all_splits))):
33
  print(all_splits[i].page_content)
34
  return all_splits
35
 
 
42
  gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
43
  return gpu_index
44
 
45
+ def generate_openai_response(self, messages, max_tokens=1000):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  try:
47
+ response = self.openai_client.chat.completions.create(
48
+ model=self.openai_model_id,
49
+ messages=messages,
50
+ max_tokens=max_tokens,
 
 
 
51
  temperature=0.8,
52
+ top_p=1.0,
53
+ frequency_penalty=0,
54
+ presence_penalty=0
55
  )
56
+ return response.choices[0].message.content
 
 
 
 
 
 
 
 
57
  except Exception as e:
58
+ print(f"Error in generate_openai_response: {str(e)}")
59
  return "Text generation process encountered an error"
60
 
61
  def retrieval_agent(self, query):
 
68
  return content
69
 
70
  def grading_agent(self, query, retrieved_content):
71
+ messages = [
72
+ {"role": "system", "content": "You are an expert at evaluating the relevance of retrieved content to a query."},
73
+ {"role": "user", "content": f"""
74
+ Evaluate the relevance of the following retrieved content to the given query:
75
+
76
+ Query: {query}
77
+
78
+ Retrieved Content:
79
+ {retrieved_content}
80
+
81
+ Rate the relevance on a scale of 1-10 and explain your rating:
82
+ """}
83
+ ]
84
 
85
+ grading_response = self.generate_openai_response(messages)
 
 
 
86
 
87
  # Extract the numerical rating from the response
88
+ match = re.search(r'\b([1-9]|10)\b', grading_response)
89
+ rating = int(match.group()) if match else 5 # Default to 5 if no rating found
90
  return rating, grading_response
91
 
92
  def query_rewrite_agent(self, original_query):
93
+ messages = [
94
+ {"role": "system", "content": "You are an expert at rewriting queries to improve information retrieval results."},
95
+ {"role": "user", "content": f"""
96
+ The following query did not yield relevant results. Please rewrite it to potentially improve retrieval:
97
+
98
+ Original Query: {original_query}
99
+
100
+ Rewritten Query:
101
+ """}
102
+ ]
103
 
104
+ rewritten_query = self.generate_openai_response(messages)
 
 
 
105
  return rewritten_query.strip()
106
 
107
  def generation_agent(self, query, retrieved_content):
108
+ messages = [
109
  {"role": "system", "content": "You are a knowledgeable assistant with access to a comprehensive database."},
110
  {"role": "user", "content": f"""
111
  I need you to answer my question and provide related information in a specific format.
112
  I have provided five relatable json files {retrieved_content}, choose the most suitable chunks for answering the query.
113
  RETURN ONLY SOLUTION without additional comments, sign-offs, retrived chunks, refrence to any Ticket or extra phrases. Be direct and to the point.
114
  IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
115
+ DO NOT GIVE REFRENCE TO ANY CHUNKS OR TICKETS, BE ON POINT.
116
 
117
  Here's my question:
118
  Query: {query}
 
120
  """}
121
  ]
122
 
123
+ return self.generate_openai_response(messages)
 
124
 
125
  def run_multi_agent_rag(self, query):
126
  max_iterations = 3
127
  for i in range(max_iterations):
 
128
  retrieved_content = self.retrieval_agent(query)
129
 
 
130
  relevance_score, grading_explanation = self.grading_agent(query, retrieved_content)
131
 
132
+ if relevance_score >= 7:
 
133
  answer = self.generation_agent(query, retrieved_content)
134
  return answer, retrieved_content, grading_explanation
135
  else:
 
136
  query = self.query_rewrite_agent(query)
137
 
138
  return "Unable to find a relevant answer after multiple attempts.", "", "Low relevance across all attempts."
 
177
 
178
  if __name__ == "__main__":
179
  embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
180
+ openai_model_id = "gpt-4-turbo"
181
  data_folder = 'sample_embedding_folder2'
182
 
183
+ multi_agent_rag = MultiAgentRAG(embedding_model_name, openai_model_id, data_folder)
184
  launch_interface(multi_agent_rag)