Update app.py
app.py
CHANGED
Before (removed lines are prefixed with '-'):

@@ -1,4 +1,6 @@
 import os
 from langchain.document_loaders import TextLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
@@ -8,61 +10,44 @@ import torch
 import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from datetime import datetime
 import gradio as gr
 import re
 from threading import Thread

-
-
-        self.model = model
-        self.tokenizer = tokenizer
-        self.embeddings = embeddings
-        self.document_vectors = self.create_document_vectors(documents_dict)
-
-    def create_document_vectors(self, documents_dict):
-        document_vectors = {}
-        for doc_name, content in documents_dict.items():
-            vectors = self.embeddings.encode(content, convert_to_tensor=True)
-            document_vectors[doc_name] = vectors
-        return document_vectors
-
-    def query(self, user_input):
-        query_vector = self.embeddings.encode(user_input, convert_to_tensor=True)
-
-        # Find the most similar document
-        most_similar_doc = max(self.document_vectors.items(),
-                               key=lambda x: torch.cosine_similarity(query_vector, x[1], dim=0))
-
-        # Generate response using the most similar document as context
-        response = self.generate_response(user_input, most_similar_doc[0], most_similar_doc[1])
-        return response
-
-    def generate_response(self, query, doc_name, doc_vector):
-        prompt = f"Based on the document '{doc_name}', answer the following question: {query}"
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
-
-        with torch.no_grad():
-            output = self.model.generate(input_ids, max_length=150, num_return_sequences=1)
-
-        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
-        return response

 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
-        self.
         self.embeddings = SentenceTransformer(embedding_model_name)
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
-        self.

     def load_documents(self, folder_path):
-
-
-
-
-
-
-
-

     def initialize_llm(self, model_id):
         quantization_config = BitsAndBytesConfig(
@@ -80,44 +65,79 @@ class DocumentRetrievalAndGeneration:
         )
         return tokenizer, model

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

     def query_and_generate_response(self, query):
-
-

     def qa_infer_gradio(self, query):
-        response
-        return response

 if __name__ == "__main__":
-    embedding_model_name = '
-    lm_model_id = "
     data_folder = 'sample_embedding_folder2'

     doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
@@ -151,7 +171,7 @@ if __name__ == "__main__":
         cache_examples=False,
         outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
         css=css_code,
-        title="TI E2E FORUM"
     )

     interface.launch(debug=True)
After (added lines are prefixed with '+'):

 import os
+import multiprocessing
+import concurrent.futures
 from langchain.document_loaders import TextLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
...
 import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from datetime import datetime
+import json
 import gradio as gr
 import re
 from threading import Thread
+from transformers.agents import Tool, HfEngine, ReactJsonAgent
+from huggingface_hub import InferenceClient
+import logging

+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
+        self.all_splits = self.load_documents(data_folder)
         self.embeddings = SentenceTransformer(embedding_model_name)
+        self.vectordb = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
+        self.retriever_tool = self.create_retriever_tool()
+        self.agent = self.create_agent()

     def load_documents(self, folder_path):
+        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
+        documents = loader.load()
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
+        all_splits = text_splitter.split_documents(documents)
+        logger.info(f'Loaded {len(documents)} documents')
+        logger.info(f"Split into {len(all_splits)} chunks")
+        return all_splits
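
The new load_documents pulls every file in the folder through TextLoader and splits it into 200-character chunks with 20 characters of overlap. If the folder can contain anything other than plain text, a narrower glob avoids loader errors; a minimal sketch, assuming the corpus is .txt files (glob and use_multithreading are standard DirectoryLoader options):

from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_text_documents(folder_path):
    # Restrict loading to .txt files and read them on several threads.
    loader = DirectoryLoader(folder_path, glob="**/*.txt", loader_cls=TextLoader, use_multithreading=True)
    documents = loader.load()
    # chunk_size and chunk_overlap are measured in characters for this splitter.
    splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
    return splitter.split_documents(documents)
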
+
+    def create_faiss_index(self):
+        all_texts = [split.page_content for split in self.all_splits]
+        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
+        vectordb = FAISS.from_embeddings(
+            embeddings,
+            self.embeddings,
+            metadatas=[{"source": f"doc_{i}"} for i in range(len(all_texts))]
+        )
+        return vectordb
51 |
|
52 |
def initialize_llm(self, model_id):
|
53 |
quantization_config = BitsAndBytesConfig(
|
|
|
65 |
)
|
66 |
return tokenizer, model
|
67 |
|
68 |
+
def create_retriever_tool(self):
|
69 |
+
class RetrieverTool(Tool):
|
70 |
+
name = "retriever"
|
71 |
+
description = "Retrieves documents from the knowledge base that are semantically similar to the input query."
|
72 |
+
inputs = {
|
73 |
+
"query": {
|
74 |
+
"type": "text",
|
75 |
+
"description": "The query to perform. Use affirmative form rather than a question.",
|
76 |
+
}
|
77 |
+
}
|
78 |
+
output_type = "text"
|
79 |
+
|
80 |
+
def __init__(self, vectordb, **kwargs):
|
81 |
+
super().__init__(**kwargs)
|
82 |
+
self.vectordb = vectordb
|
83 |
+
|
84 |
+
def forward(self, query: str) -> str:
|
85 |
+
docs = self.vectordb.similarity_search(query, k=3)
|
86 |
+
return "\nRetrieved documents:\n" + "".join(
|
87 |
+
[f"===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
|
88 |
+
)
|
89 |
+
|
90 |
+
return RetrieverTool(self.vectordb)
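
RetrieverTool follows the transformers.agents Tool interface (name, description, inputs, output_type, and a forward method); calling the tool instance dispatches to forward, and k=3 caps how many chunks come back per call. Note that some newer transformers releases validate tool input types and expect "string" rather than "text". A short standalone check of the tool outside the agent loop, assuming rag is an initialized DocumentRetrievalAndGeneration and the query is purely illustrative:

rag = DocumentRetrievalAndGeneration('thenlper/gte-small',
                                     "meta-llama/Meta-Llama-3.1-8B-Instruct",
                                     'sample_embedding_folder2')
# __call__ on a Tool runs forward(), so this prints the three retrieved chunks.
print(rag.retriever_tool(query="buck converter output voltage configuration"))
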
+
+    def create_agent(self):
+        llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
+        return ReactJsonAgent(tools=[self.retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)
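
create_agent wires the retriever into a ReAct-style loop: HfEngine calls the hosted Meta-Llama-3.1-8B-Instruct chat endpoint, and ReactJsonAgent lets the model emit JSON-formatted tool calls for at most max_iterations=4 rounds before producing a final answer (verbose=2 logs each step). A minimal sketch of exercising the agent directly, assuming network access to the Hugging Face Inference API and reusing the rag instance from the sketch above; the question is illustrative:

answer = rag.agent.run("Summarize the recommended power-up sequence.")
print(answer)
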
+
+    def run_agentic_rag(self, question: str) -> str:
+        enhanced_question = f"""Using the information in your knowledge base, accessible with the 'retriever' tool,
+give a comprehensive answer to the question below.
+Respond only to the question asked, be concise and relevant.
+If you can't find information, try calling your retriever again with different arguments.
+Make sure to cover the question completely by calling the retriever tool several times with semantically different queries.
+Your queries should be in affirmative form, not questions.
+
+Question:
+{question}"""
+
+        return self.agent.run(enhanced_question)
+
+    def run_standard_rag(self, question: str) -> str:
+        context = self.retriever_tool(query=question)
+
+        prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
+Respond only to the question asked, be concise and relevant.
+Provide the number of the source document when relevant.
+
+Question:
+{question}
+
+{context}
+"""
+        messages = [{"role": "user", "content": prompt}]
+
+        reader_llm = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct")
+
+        return reader_llm.chat_completion(messages).choices[0].message.content
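
run_standard_rag is the single-shot baseline: one retrieval with the question as-is, then one chat_completion call to the same hosted Llama-3.1-8B-Instruct endpoint via huggingface_hub's InferenceClient. With both paths using the hosted endpoint, the locally quantized model returned by initialize_llm is no longer used for generation in the visible part of the diff. chat_completion's default token budget is fairly small in many huggingface_hub releases, so a hedged variant that raises it (1000 is an arbitrary illustrative value):

from huggingface_hub import InferenceClient

reader_llm = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct")
completion = reader_llm.chat_completion(
    messages=[{"role": "user", "content": "question plus retrieved context go here"}],
    max_tokens=1000,  # avoid truncated answers; pick a budget that fits expected responses
)
print(completion.choices[0].message.content)
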

     def query_and_generate_response(self, query):
+        agentic_answer = self.run_agentic_rag(query)
+        standard_answer = self.run_standard_rag(query)
+
+        combined_answer = f"Agentic RAG Answer:\n{agentic_answer}\n\nStandard RAG Answer:\n{standard_answer}"
+        return combined_answer, ""  # Return empty string for 'content' as it's not used in this implementation

     def qa_infer_gradio(self, query):
+        response = self.query_and_generate_response(query)
+        return response

 if __name__ == "__main__":
+    embedding_model_name = 'thenlper/gte-small'
+    lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
     data_folder = 'sample_embedding_folder2'

     doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
...
         cache_examples=False,
         outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
         css=css_code,
+        title="TI E2E FORUM Multi-Agent RAG"
     )

     interface.launch(debug=True)
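
The diff hides the middle of the __main__ block, where the Gradio interface is constructed, so only cache_examples, outputs, css, and the updated title are visible. For orientation, a hypothetical wiring that is consistent with those kwargs and with qa_infer_gradio returning a two-element tuple for the two output Textboxes; css_code and EXAMPLES here are placeholders, not the hidden values:

css_code = ".gradio-container {max-width: 90% !important;}"  # placeholder, not the hidden css_code
EXAMPLES = ["example query 1", "example query 2"]             # placeholder example prompts

interface = gr.Interface(
    fn=doc_retrieval_gen.qa_infer_gradio,  # returns (combined_answer, "") for the two Textboxes
    inputs=gr.Textbox(label="QUERY", placeholder="Ask a question about the documents"),
    examples=EXAMPLES,
    cache_examples=False,
    outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
    css=css_code,
    title="TI E2E FORUM Multi-Agent RAG",
)
interface.launch(debug=True)
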